diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 569ca7d6e..2c3f1bd5b 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -721,30 +721,25 @@ impl Receiver { Receiver, > { match self.state.original.check_inputs_not_owned(is_owned) { - Ok(inner) => inner, + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedInputsNotOwned(), + Receiver { + state: MaybeInputsSeen { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedInputsNotOwned(), - Receiver { - state: MaybeInputsSeen { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_inputs_not_owned(self) -> ReceiveSession { @@ -783,30 +778,25 @@ impl Receiver { Receiver, > { match self.state.original.check_no_inputs_seen_before(is_known) { - Ok(inner) => inner, + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedNoInputsSeenBefore(), + Receiver { + state: OutputsUnknown { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedNoInputsSeenBefore(), - Receiver { - state: OutputsUnknown { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_no_inputs_seen_before(self) -> ReceiveSession { @@ -849,28 +839,23 @@ impl Receiver { Error, Receiver, > { - let inner = match self.state.original.identify_receiver_outputs(is_receiver_output) { - Ok(inner) => inner, + match self.state.original.identify_receiver_outputs(is_receiver_output) { + Ok(inner) => MaybeFatalTransition::success( + SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), + Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), - Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, - ) + } } pub(crate) fn apply_identified_receiver_outputs( @@ -1054,23 +1039,20 @@ impl Receiver { ) -> MaybeFatalTransition, ProtocolError> { let max_effective_fee_rate = max_effective_fee_rate.or(Some(self.session_context.max_fee_rate)); - let psbt_context = match self + match self .state .inner .calculate_psbt_context_with_fee_range(min_fee_rate, max_effective_fee_rate) { - Ok(inner) => inner, - Err(e) => { - return MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())); - } - }; - MaybeFatalTransition::success( - SessionEvent::AppliedFeeRange(psbt_context.clone()), - Receiver { - state: ProvisionalProposal { psbt_context }, - session_context: self.session_context, - }, - ) + Ok(psbt_context) => MaybeFatalTransition::success( + SessionEvent::AppliedFeeRange(psbt_context.clone()), + Receiver { + state: ProvisionalProposal { psbt_context }, + session_context: self.session_context, + }, + ), + Err(e) => MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())), + } } pub(crate) fn apply_applied_fee_range(self, psbt_context: PsbtContext) -> ReceiveSession { @@ -1103,16 +1085,16 @@ impl Receiver { ) -> MaybeTransientTransition, ImplementationError> { let original_psbt = self.state.psbt_context.original_psbt.clone(); - let inner = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { - Ok(inner) => inner, + let payjoin_psbt = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { + Ok(payjoin_psbt) => payjoin_psbt, Err(e) => { return MaybeTransientTransition::transient(e); } }; - let psbt_context = PsbtContext { payjoin_psbt: inner.clone(), original_psbt }; + let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; MaybeTransientTransition::success( - SessionEvent::FinalizedProposal(inner), + SessionEvent::FinalizedProposal(payjoin_psbt), Receiver { state: payjoin_proposal, session_context: self.session_context }, ) } @@ -1599,20 +1581,22 @@ pub mod test { Ok(ret) } - let maybe_inputs_seen = - receiver.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)); + let maybe_inputs_seen = receiver + .check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)) + .save(&persister) + .expect("Persister shouldn't fail"); assert_eq!(call_count, 1); let outputs_unknown = maybe_inputs_seen - .save(&persister) - .expect("Persister shouldn't fail") .check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false)) .save(&persister) .expect("Persister shouldn't fail"); assert_eq!(call_count, 2); let _wants_outputs = outputs_unknown - .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)); + .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)) + .save(&persister) + .expect("Persister shouldn't fail"); // there are 2 receiver outputs so we should expect this callback to run twice incrementing // call count twice assert_eq!(call_count, 4);