diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index f557facd3..41681a6d1 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -13,7 +13,7 @@ use hyper_util::rt::TokioIo; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::receive::v1::{PayjoinProposal, UncheckedOriginalPayload}; -use payjoin::receive::Error; +use payjoin::receive::{check_references, Error}; use payjoin::send::v1::SenderBuilder; use payjoin::{ImplementationError, IntoUrl, Uri, UriExt}; use tokio::net::TcpListener; @@ -347,33 +347,37 @@ impl App { let wallet = self.wallet(); // Receive Check 1: Can Broadcast - let proposal = proposal.check_broadcast_suitability(None, |tx| { - wallet - .can_broadcast(tx) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - })?; + let is_broadcast_suitable = wallet + .can_broadcast(&proposal.extract_tx_to_check_broadcast_suitability()) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = proposal.apply_broadcast_suitability(None, is_broadcast_suitable)?; tracing::trace!("check1"); // in a payment processor where the sender could go offline, this is where you schedule to broadcast the original_tx let _to_broadcast_in_failure_case = proposal.extract_tx_to_schedule_broadcast(); // Receive Check 2: receiver can't sign for proposal inputs - let proposal = proposal.check_inputs_not_owned(&mut |input| { + let refs = proposal.get_input_script_refs()?; + let checked_refs = check_references(refs, &mut |input| { wallet.is_mine(input).map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) })?; + let proposal = proposal.apply_input_owned_checks(checked_refs)?; tracing::trace!("check2"); // Receive Check 3: have we seen this input before? More of a check for non-interactive i.e. payment processor receivers. - let payjoin = proposal.check_no_inputs_seen_before(&mut |input| { - Ok(self.db.insert_input_seen_before(*input)?) - })?; + let refs = proposal.get_input_outpoint_refs(); + let checked_refs = + check_references(refs, &mut |input| Ok(self.db.insert_input_seen_before(*input)?))?; + let payjoin = proposal.apply_input_seen_checks(checked_refs)?; tracing::trace!("check3"); - let payjoin = payjoin.identify_receiver_outputs(&mut |output_script| { + let refs = payjoin.get_output_script_refs(); + let checked_refs = check_references(refs, &mut |output_script| { wallet .is_mine(output_script) .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) })?; + let payjoin = payjoin.apply_output_owned_checks(checked_refs)?; let payjoin = payjoin .substitute_receiver_script( diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 67dca2910..739418618 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, Context, Result}; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::persist::{OptionalTransitionOutcome, SessionPersister}; +use payjoin::receive::check_references; use payjoin::receive::v2::{ replay_event_log as replay_receiver_event_log, HasReplyableError, Initialized, MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, @@ -702,13 +703,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .check_broadcast_suitability(None, |tx| { - wallet - .can_broadcast(tx) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let is_broadcast_suitable = wallet + .can_broadcast(&proposal.extract_tx_to_check_broadcast_suitability()) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = + proposal.apply_broadcast_suitability(None, is_broadcast_suitable).save(persister)?; println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:"); println!("{}", serialize_hex(&proposal.extract_tx_to_schedule_broadcast())); @@ -721,13 +720,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .check_inputs_not_owned(&mut |input| { - wallet - .is_mine(input) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let refs = proposal.get_input_script_refs()?; + let checked_refs = check_references(refs, &mut |input| { + wallet.is_mine(input).map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + })?; + let proposal = proposal.apply_input_owned_checks(checked_refs).save(persister)?; self.check_no_inputs_seen_before(proposal, persister).await } @@ -736,11 +733,10 @@ impl App { proposal: Receiver, persister: &ReceiverPersister, ) -> Result<()> { - let proposal = proposal - .check_no_inputs_seen_before(&mut |input| { - Ok(self.db.insert_input_seen_before(*input)?) - }) - .save(persister)?; + let refs = proposal.get_input_outpoint_refs(); + let checked_refs = + check_references(refs, &mut |input| Ok(self.db.insert_input_seen_before(*input)?))?; + let proposal = proposal.apply_input_seen_checks(checked_refs).save(persister)?; self.identify_receiver_outputs(proposal, persister).await } @@ -750,13 +746,13 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .identify_receiver_outputs(&mut |output_script| { - wallet - .is_mine(output_script) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let refs = proposal.get_output_script_refs(); + let checked_refs = check_references(refs, &mut |output_script| { + wallet + .is_mine(output_script) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + })?; + let proposal = proposal.apply_output_owned_checks(checked_refs).save(persister)?; self.commit_outputs(proposal, persister).await } @@ -804,13 +800,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .finalize_proposal(|psbt| { - wallet - .process_psbt(psbt) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let psbt = proposal.psbt_to_sign(); + let signed_psbt = wallet + .process_psbt(&psbt) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = proposal.finalize_signed_proposal(&signed_psbt).save(persister)?; self.send_payjoin_proposal(proposal, persister).await } @@ -847,27 +841,32 @@ impl App { tracing::debug!("Polling for payment confirmation"); + let fallback_txid = proposal.extract_fallback_txid(); + let payjoin_txid = proposal.extract_payjoin_proposal_txid(); + let get_raw_tx = |txid| { + self.wallet() + .get_raw_transaction(&txid) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + }; + match proposal.check_fallback_monitorable().save(persister)? { + OptionalTransitionOutcome::Progress(_) => { + println!("Unable to monitor for fallback tx containing non-segwit inputs, completing session"); + return Ok(()); + } + OptionalTransitionOutcome::Stasis(_) => {} + } let result = tokio::time::timeout(timeout_duration, async { loop { interval.tick().await; - let check_result = proposal - .check_payment(|txid| { - self.wallet() - .get_raw_transaction(&txid) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister); - - match check_result { - Ok(_) => { - println!("Payjoin transaction detected in the mempool!"); - return Ok(()); - } - Err(_) => { - // keep polling - - continue; - } + if let Some(tx) = get_raw_tx(payjoin_txid)? { + proposal.payjoin_tx_exists(tx).save(persister)?; + println!("Payjoin transaction detected in the mempool!"); + return Ok(()); + }; + if get_raw_tx(fallback_txid)?.is_some() { + proposal.fallback_tx_exists().save(persister)?; + println!("Fallback transaction detected in the mempool!"); + return Ok(()); } } }) diff --git a/payjoin-ffi/csharp/IntegrationTests.cs b/payjoin-ffi/csharp/IntegrationTests.cs index 032686f65..10951837c 100644 --- a/payjoin-ffi/csharp/IntegrationTests.cs +++ b/payjoin-ffi/csharp/IntegrationTests.cs @@ -5,6 +5,12 @@ namespace Payjoin.Tests { + public enum TransitionMode + { + Callback, + Nonblocking, + } + public class IntegrationTests : IAsyncLifetime { private static string RpcCall(RpcClient rpc, string method, params string?[] args) => rpc.Call(method, args); @@ -168,6 +174,7 @@ private static InputPair[] GetInputs(RpcClient rpc) RpcClient receiverRpc, InMemoryReceiverPersister recvPersister, string ohttpRelay, + TransitionMode mode, CancellationToken cancellationToken) { var request = receiver.CreatePollRequest(ohttpRelay); @@ -192,7 +199,7 @@ private static InputPair[] GetInputs(RpcClient rpc) if (outcome is InitializedTransitionOutcome.Progress progress) { using var proposal = progress.inner; - return await ProcessUncheckedProposal(proposal, receiverRpc, recvPersister); + return await ProcessUncheckedProposal(proposal, receiverRpc, recvPersister, mode); } throw new InvalidOperationException("Unknown initialized transition outcome"); @@ -201,88 +208,157 @@ private static InputPair[] GetInputs(RpcClient rpc) private Task ProcessUncheckedProposal( UncheckedOriginalPayload proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var checkedTransition = proposal.CheckBroadcastSuitability(null, new MempoolAcceptanceCallback(receiverRpc)); - using var maybeInputsOwned = checkedTransition.Save(recvPersister); + MaybeInputsOwned maybeInputsOwned; + + if (mode == TransitionMode.Callback) + { + using var checkedTransition = proposal.CheckBroadcastSuitability(null, new MempoolAcceptanceCallback(receiverRpc)); + maybeInputsOwned = checkedTransition.Save(recvPersister); + } + else + { + var canBroadcast = new MempoolAcceptanceCallback(receiverRpc).Callback(proposal.ExtractTxToCheckBroadcastSuitability()); + using var checkedTransition = proposal.ApplyBroadcastSuitability(null, canBroadcast); + maybeInputsOwned = checkedTransition.Save(recvPersister); + } - return ProcessMaybeInputsOwned(maybeInputsOwned, receiverRpc, recvPersister); + return ProcessMaybeInputsOwned(maybeInputsOwned, receiverRpc, recvPersister, mode); } private Task ProcessMaybeInputsOwned( MaybeInputsOwned proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.CheckInputsNotOwned(new IsScriptOwnedCallback(receiverRpc)); - using var maybeInputsSeen = transition.Save(recvPersister); + MaybeInputsSeen maybeInputsSeen; - return ProcessMaybeInputsSeen(maybeInputsSeen, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.CheckInputsNotOwned(new IsScriptOwnedCallback(receiverRpc)); + maybeInputsSeen = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetInputScriptRefs() + .Select(r => r.Mark(new IsScriptOwnedCallback(receiverRpc).Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyInputOwnedChecks(taggedRefs); + maybeInputsSeen = transition.Save(recvPersister); + } + + return ProcessMaybeInputsSeen(maybeInputsSeen, receiverRpc, recvPersister, mode); } private Task ProcessMaybeInputsSeen( MaybeInputsSeen proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.CheckNoInputsSeenBefore(new CheckInputsNotSeenCallback()); - using var outputsUnknown = transition.Save(recvPersister); + OutputsUnknown outputsUnknown; - return ProcessOutputsUnknown(outputsUnknown, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.CheckNoInputsSeenBefore(new CheckInputsNotSeenCallback()); + outputsUnknown = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetInputOutpointRefs() + .Select(r => r.Mark(new CheckInputsNotSeenCallback().Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyInputSeenChecks(taggedRefs); + outputsUnknown = transition.Save(recvPersister); + } + + return ProcessOutputsUnknown(outputsUnknown, receiverRpc, recvPersister, mode); } private Task ProcessOutputsUnknown( OutputsUnknown proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.IdentifyReceiverOutputs(new IsScriptOwnedCallback(receiverRpc)); - using var wantsOutputs = transition.Save(recvPersister); + WantsOutputs wantsOutputs; - return ProcessWantsOutputs(wantsOutputs, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.IdentifyReceiverOutputs(new IsScriptOwnedCallback(receiverRpc)); + wantsOutputs = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetOutputScriptRefs() + .Select(r => r.Mark(new IsScriptOwnedCallback(receiverRpc).Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyOutputOwnedChecks(taggedRefs); + wantsOutputs = transition.Save(recvPersister); + } + + return ProcessWantsOutputs(wantsOutputs, receiverRpc, recvPersister, mode); } private Task ProcessWantsOutputs( WantsOutputs proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var transition = proposal.CommitOutputs(); using var wantsInputs = transition.Save(recvPersister); - return ProcessWantsInputs(wantsInputs, receiverRpc, recvPersister); + return ProcessWantsInputs(wantsInputs, receiverRpc, recvPersister, mode); } private Task ProcessWantsInputs( WantsInputs proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var contributed = proposal.ContributeInputs(GetInputs(receiverRpc)); using var transition = contributed.CommitInputs(); using var wantsFeeRange = transition.Save(recvPersister); - return ProcessWantsFeeRange(wantsFeeRange, receiverRpc, recvPersister); + return ProcessWantsFeeRange(wantsFeeRange, receiverRpc, recvPersister, mode); } private Task ProcessWantsFeeRange( WantsFeeRange proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var transition = proposal.ApplyFeeRange(1, 10); using var provisional = transition.Save(recvPersister); - return ProcessProvisionalProposal(provisional, receiverRpc, recvPersister); + return ProcessProvisionalProposal(provisional, receiverRpc, recvPersister, mode); } private Task ProcessProvisionalProposal( ProvisionalProposal proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.FinalizeProposal(new ProcessPsbtCallback(receiverRpc)); - var payjoinProposal = transition.Save(recvPersister); + PayjoinProposal payjoinProposal; + + if (mode == TransitionMode.Callback) + { + using var transition = proposal.FinalizeProposal(new ProcessPsbtCallback(receiverRpc)); + payjoinProposal = transition.Save(recvPersister); + } + else + { + var signedPsbt = new ProcessPsbtCallback(receiverRpc).Callback(proposal.PsbtToSign()); + using var transition = proposal.FinalizeSignedProposal(signedPsbt); + payjoinProposal = transition.Save(recvPersister); + } return Task.FromResult(payjoinProposal); } @@ -438,8 +514,10 @@ public void TestFfiValidation() }); } - [Fact] - public async Task TestIntegrationV2ToV2() + [Theory] + [InlineData(TransitionMode.Callback)] + [InlineData(TransitionMode.Nonblocking)] + public async Task TestIntegrationV2ToV2(TransitionMode mode) { var cancellationToken = TestContext.Current.CancellationToken; @@ -465,7 +543,7 @@ public async Task TestIntegrationV2ToV2() using var receiveTransition = receiverBuilder.Build(); using var session = receiveTransition.Save(recvPersister); - var initial = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, cancellationToken); + var initial = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, mode, cancellationToken); Assert.Null(initial); // ***************************** @@ -500,7 +578,7 @@ public async Task TestIntegrationV2ToV2() // ********************* // RECEIVER SIDE // Poll for the proposal - using var payjoinProposal = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, cancellationToken); + using var payjoinProposal = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, mode, cancellationToken); Assert.NotNull(payjoinProposal); Assert.IsType(payjoinProposal); diff --git a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart index b40d40ea9..df6ee4aef 100644 --- a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart +++ b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart @@ -15,6 +15,8 @@ late test_utils.BitcoindInstance bitcoind; late test_utils.RpcClient receiver; late test_utils.RpcClient sender; +enum TransitionMode { callback, nonblocking } + class MempoolAcceptanceCallback implements payjoin.CanBroadcast { final payjoin.RpcClient connection; @@ -202,91 +204,186 @@ List get_inputs(payjoin.RpcClient rpc_connection) { Future process_provisional_proposal( payjoin.ProvisionalProposal proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final payjoin_proposal = proposal - .finalizeProposal(processPsbt: ProcessPsbtCallback(receiver)) - .save(persister: recv_persister); + final payjoin.PayjoinProposal payjoin_proposal; + if (mode == TransitionMode.callback) { + payjoin_proposal = proposal + .finalizeProposal(processPsbt: ProcessPsbtCallback(receiver)) + .save(persister: recv_persister); + } else { + final signed_psbt = ProcessPsbtCallback( + receiver, + ).callback(proposal.psbtToSign()); + payjoin_proposal = proposal + .finalizeSignedProposal(signedPsbt: signed_psbt) + .save(persister: recv_persister); + } return payjoin.PayjoinProposalReceiveSession(payjoin_proposal); } Future process_wants_fee_range( payjoin.WantsFeeRange proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final wants_fee_range = proposal .applyFeeRange(minFeeRateSatPerVb: 1, maxEffectiveFeeRateSatPerVb: 10) .save(persister: recv_persister); - return await process_provisional_proposal(wants_fee_range, recv_persister); + return await process_provisional_proposal( + wants_fee_range, + recv_persister, + mode, + ); } Future process_wants_inputs( payjoin.WantsInputs proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final provisional_proposal = proposal .contributeInputs(replacementInputs: get_inputs(receiver)) .commitInputs() .save(persister: recv_persister); - return await process_wants_fee_range(provisional_proposal, recv_persister); + return await process_wants_fee_range( + provisional_proposal, + recv_persister, + mode, + ); } Future process_wants_outputs( payjoin.WantsOutputs proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final wants_inputs = proposal.commitOutputs().save(persister: recv_persister); - return await process_wants_inputs(wants_inputs, recv_persister); + return await process_wants_inputs(wants_inputs, recv_persister, mode); } Future process_outputs_unknown( payjoin.OutputsUnknown proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final wants_outputs = proposal - .identifyReceiverOutputs( - isReceiverOutput: IsScriptOwnedCallback(receiver), - ) - .save(persister: recv_persister); - return await process_wants_outputs(wants_outputs, recv_persister); + final payjoin.WantsOutputs wants_outputs; + if (mode == TransitionMode.callback) { + wants_outputs = proposal + .identifyReceiverOutputs( + isReceiverOutput: IsScriptOwnedCallback(receiver), + ) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getOutputScriptRefs() + .map( + (ref) => ref.mark( + result: IsScriptOwnedCallback(receiver).callback(ref.getValue()), + ), + ) + .toList(); + wants_outputs = proposal + .applyOutputOwnedChecks(checkedOutputScripts: tagged_refs) + .save(persister: recv_persister); + } + return await process_wants_outputs(wants_outputs, recv_persister, mode); } Future process_maybe_inputs_seen( payjoin.MaybeInputsSeen proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final outputs_unknown = proposal - .checkNoInputsSeenBefore(isKnown: CheckInputsNotSeenCallback(receiver)) - .save(persister: recv_persister); - return await process_outputs_unknown(outputs_unknown, recv_persister); + final payjoin.OutputsUnknown outputs_unknown; + if (mode == TransitionMode.callback) { + outputs_unknown = proposal + .checkNoInputsSeenBefore(isKnown: CheckInputsNotSeenCallback(receiver)) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getInputOutpointRefs() + .map( + (ref) => ref.mark( + result: CheckInputsNotSeenCallback( + receiver, + ).callback(ref.getValue()), + ), + ) + .toList(); + outputs_unknown = proposal + .applyInputSeenChecks(checkedInputOutpoints: tagged_refs) + .save(persister: recv_persister); + } + return await process_outputs_unknown(outputs_unknown, recv_persister, mode); } Future process_maybe_inputs_owned( payjoin.MaybeInputsOwned proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final maybe_inputs_owned = proposal - .checkInputsNotOwned(isOwned: IsScriptOwnedCallback(receiver)) - .save(persister: recv_persister); - return await process_maybe_inputs_seen(maybe_inputs_owned, recv_persister); + final payjoin.MaybeInputsSeen maybe_inputs_owned; + if (mode == TransitionMode.callback) { + maybe_inputs_owned = proposal + .checkInputsNotOwned(isOwned: IsScriptOwnedCallback(receiver)) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getInputScriptRefs() + .map( + (ref) => ref.mark( + result: IsScriptOwnedCallback(receiver).callback(ref.getValue()), + ), + ) + .toList(); + maybe_inputs_owned = proposal + .applyInputOwnedChecks(checkedInputScripts: tagged_refs) + .save(persister: recv_persister); + } + return await process_maybe_inputs_seen( + maybe_inputs_owned, + recv_persister, + mode, + ); } Future process_unchecked_proposal( payjoin.UncheckedOriginalPayload proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final unchecked_proposal = proposal - .checkBroadcastSuitability( - minFeeRateSatPerKwu: null, - canBroadcast: MempoolAcceptanceCallback(receiver), - ) - .save(persister: recv_persister); - return await process_maybe_inputs_owned(unchecked_proposal, recv_persister); + final payjoin.MaybeInputsOwned unchecked_proposal; + if (mode == TransitionMode.callback) { + unchecked_proposal = proposal + .checkBroadcastSuitability( + minFeeRateSatPerKwu: null, + canBroadcast: MempoolAcceptanceCallback(receiver), + ) + .save(persister: recv_persister); + } else { + final can_broadcast = MempoolAcceptanceCallback( + receiver, + ).callback(proposal.extractTxToCheckBroadcastSuitability()); + unchecked_proposal = proposal + .applyBroadcastSuitability( + minFeeRateSatPerKwu: null, + canBroadcast: can_broadcast, + ) + .save(persister: recv_persister); + } + return await process_maybe_inputs_owned( + unchecked_proposal, + recv_persister, + mode, + ); } Future retrieve_receiver_proposal( payjoin.Initialized receiver, InMemoryReceiverPersister recv_persister, String ohttp_relay, + TransitionMode mode, ) async { var agent = http.Client(); var request = receiver.createPollRequest(ohttpRelay: ohttp_relay); @@ -303,7 +400,7 @@ Future retrieve_receiver_proposal( return null; } else if (res is payjoin.ProgressInitializedTransitionOutcome) { var proposal = res.inner; - return await process_unchecked_proposal(proposal, recv_persister); + return await process_unchecked_proposal(proposal, recv_persister, mode); } throw Exception("Unknown initialized transition outcome: $res"); @@ -313,12 +410,14 @@ Future process_receiver_proposal( payjoin.ReceiveSession receiver, InMemoryReceiverPersister recv_persister, String ohttp_relay, + TransitionMode mode, ) async { if (receiver is payjoin.InitializedReceiveSession) { var res = await retrieve_receiver_proposal( receiver.inner, recv_persister, ohttp_relay, + mode, ); if (res == null) { return null; @@ -327,25 +426,41 @@ Future process_receiver_proposal( } if (receiver is payjoin.UncheckedOriginalPayloadReceiveSession) { - return await process_unchecked_proposal(receiver.inner, recv_persister); + return await process_unchecked_proposal( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.MaybeInputsOwnedReceiveSession) { - return await process_maybe_inputs_owned(receiver.inner, recv_persister); + return await process_maybe_inputs_owned( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.MaybeInputsSeenReceiveSession) { - return await process_maybe_inputs_seen(receiver.inner, recv_persister); + return await process_maybe_inputs_seen( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.OutputsUnknownReceiveSession) { - return await process_outputs_unknown(receiver.inner, recv_persister); + return await process_outputs_unknown(receiver.inner, recv_persister, mode); } if (receiver is payjoin.WantsOutputsReceiveSession) { - return await process_wants_outputs(receiver.inner, recv_persister); + return await process_wants_outputs(receiver.inner, recv_persister, mode); } if (receiver is payjoin.WantsInputsReceiveSession) { - return await process_wants_inputs(receiver.inner, recv_persister); + return await process_wants_inputs(receiver.inner, recv_persister, mode); } if (receiver is payjoin.ProvisionalProposalReceiveSession) { - return await process_provisional_proposal(receiver.inner, recv_persister); + return await process_provisional_proposal( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.PayjoinProposalReceiveSession) { return receiver; @@ -354,6 +469,167 @@ Future process_receiver_proposal( throw Exception("Unknown receiver state: $receiver"); } +Future run_integration_v2_to_v2(TransitionMode mode) async { + env = test_utils.initBitcoindSenderReceiver(); + bitcoind = env.getBitcoind(); + receiver = env.getReceiver(); + sender = env.getSender(); + var receiver_address = + jsonDecode(receiver.call(method: "getnewaddress", params: [])) as String; + var services = test_utils.TestServices.initialize(); + + services.waitForServicesReady(); + var directory = services.directoryUrl(); + var ohttp_keys = services.fetchOhttpKeys(); + var ohttp_relay = services.ohttpRelayUrl(); + var agent = http.Client(); + + // ********************** + // Inside the Receiver: + var recv_persister = InMemoryReceiverPersister(); + var sender_persister = InMemorySenderPersister(); + var session = create_receiver_context( + receiver_address, + directory, + ohttp_keys, + recv_persister, + ); + var process_response = await process_receiver_proposal( + payjoin.InitializedReceiveSession(session), + recv_persister, + ohttp_relay, + mode, + ); + expect(process_response, isNull); + + // ********************** + // Inside the Sender: + // Create a funded PSBT (not broadcasted) to address with amount given in the pj_uri + var pj_uri = session.pjUri(); + var psbt = build_sweep_psbt(sender, pj_uri); + payjoin.WithReplyKey req_ctx = payjoin.SenderBuilder(psbt: psbt, uri: pj_uri) + .buildRecommended(minFeeRateSatPerKwu: 1000) + .save(persister: sender_persister); + payjoin.RequestOhttpContext request = req_ctx.createV2PostRequest( + ohttpRelay: ohttp_relay, + ); + var response = await agent.post( + Uri.parse(request.request.url), + headers: {"Content-Type": request.request.contentType}, + body: request.request.body, + ); + payjoin.PollingForProposal send_ctx = req_ctx + .processResponse(response: response.bodyBytes, postCtx: request.ohttpCtx) + .save(persister: sender_persister); + // POST Original PSBT + + // ********************** + // Inside the Receiver: + + // GET fallback psbt + payjoin.ReceiveSession? payjoin_proposal = await process_receiver_proposal( + payjoin.InitializedReceiveSession(session), + recv_persister, + ohttp_relay, + mode, + ); + expect(payjoin_proposal, isNotNull); + expect(payjoin_proposal, isA()); + + payjoin.PayjoinProposal proposal = + (payjoin_proposal as payjoin.PayjoinProposalReceiveSession).inner; + payjoin.RequestResponse request_response = proposal.createPostRequest( + ohttpRelay: ohttp_relay, + ); + var fallback_response = await agent.post( + Uri.parse(request_response.request.url), + headers: {"Content-Type": request_response.request.contentType}, + body: request_response.request.body, + ); + proposal.processResponse( + body: fallback_response.bodyBytes, + ohttpContext: request_response.clientResponse, + ); + + // ********************** + // Inside the Sender: + // Sender checks, signs, finalizes, extracts, and broadcasts + // Replay post fallback to get the response + payjoin.PollingForProposalTransitionOutcome? poll_outcome; + var attempts = 0; + while (true) { + payjoin.RequestOhttpContext ohttp_context_request = send_ctx + .createPollRequest(ohttpRelay: ohttp_relay); + var final_response = await agent.post( + Uri.parse(ohttp_context_request.request.url), + headers: {"Content-Type": ohttp_context_request.request.contentType}, + body: ohttp_context_request.request.body, + ); + poll_outcome = send_ctx + .processResponse( + response: final_response.bodyBytes, + ohttpCtx: ohttp_context_request.ohttpCtx, + ) + .save(persister: sender_persister); + + if (poll_outcome is payjoin.ProgressPollingForProposalTransitionOutcome) { + break; + } + + attempts += 1; + if (attempts >= 3) { + // Receiver not ready yet; mirror Python's tolerant polling. + return; + } + } + + final progressOutcome = + poll_outcome as payjoin.ProgressPollingForProposalTransitionOutcome; + var payjoin_psbt = jsonDecode( + sender.call( + method: "walletprocesspsbt", + params: [progressOutcome.psbtBase64], + ), + )["psbt"]; + var final_psbt = jsonDecode( + sender.call( + method: "finalizepsbt", + params: [payjoin_psbt, jsonEncode(false)], + ), + )["psbt"]; + var final_tx_hex = jsonDecode( + sender.call(method: "finalizepsbt", params: [final_psbt, jsonEncode(true)]), + )["hex"]; + sender.call(method: "sendrawtransaction", params: [jsonEncode(final_tx_hex)]); + + // Check resulting transaction and balances + var decodedTx = jsonDecode( + sender.call( + method: "decoderawtransaction", + params: [jsonEncode(final_tx_hex)], + ), + ); + var network_fees = + (jsonDecode( + sender.call( + method: "decodepsbt", + params: [jsonEncode(final_psbt)], + ), + )["fee"] + as num) + .toDouble(); + // Sender sent the entire value of their utxo to the receiver (minus fees) + expect(decodedTx["vin"].length, 2); + expect(decodedTx["vout"].length, 1); + expect( + jsonDecode( + receiver.call(method: "getbalances", params: []), + )["mine"]["untrusted_pending"], + 100 - network_fees, + ); + expect(jsonDecode(sender.call(method: "getbalance", params: [])), 0.0); +} + void main() { group('fetchOhttpKeys', () { test( @@ -471,176 +747,16 @@ void main() { ); }); - test('Test integration v2 to v2', () async { - env = test_utils.initBitcoindSenderReceiver(); - bitcoind = env.getBitcoind(); - receiver = env.getReceiver(); - sender = env.getSender(); - var receiver_address = - jsonDecode(receiver.call(method: "getnewaddress", params: [])) - as String; - var services = test_utils.TestServices.initialize(); - - services.waitForServicesReady(); - var directory = services.directoryUrl(); - var ohttp_keys = services.fetchOhttpKeys(); - var ohttp_relay = services.ohttpRelayUrl(); - var agent = http.Client(); - - // ********************** - // Inside the Receiver: - var recv_persister = InMemoryReceiverPersister(); - var sender_persister = InMemorySenderPersister(); - var session = create_receiver_context( - receiver_address, - directory, - ohttp_keys, - recv_persister, - ); - var process_response = await process_receiver_proposal( - payjoin.InitializedReceiveSession(session), - recv_persister, - ohttp_relay, - ); - expect(process_response, isNull); - - // ********************** - // Inside the Sender: - // Create a funded PSBT (not broadcasted) to address with amount given in the pj_uri - var pj_uri = session.pjUri(); - var psbt = build_sweep_psbt(sender, pj_uri); - payjoin.WithReplyKey req_ctx = - payjoin.SenderBuilder(psbt: psbt, uri: pj_uri) - .buildRecommended(minFeeRateSatPerKwu: 1000) - .save(persister: sender_persister); - payjoin.RequestOhttpContext request = req_ctx.createV2PostRequest( - ohttpRelay: ohttp_relay, - ); - var response = await agent.post( - Uri.parse(request.request.url), - headers: {"Content-Type": request.request.contentType}, - body: request.request.body, - ); - payjoin.PollingForProposal send_ctx = req_ctx - .processResponse( - response: response.bodyBytes, - postCtx: request.ohttpCtx, - ) - .save(persister: sender_persister); - // POST Original PSBT - - // ********************** - // Inside the Receiver: - - // GET fallback psbt - payjoin.ReceiveSession? payjoin_proposal = - await process_receiver_proposal( - payjoin.InitializedReceiveSession(session), - recv_persister, - ohttp_relay, - ); - expect(payjoin_proposal, isNotNull); - expect(payjoin_proposal, isA()); - - payjoin.PayjoinProposal proposal = - (payjoin_proposal as payjoin.PayjoinProposalReceiveSession).inner; - payjoin.RequestResponse request_response = proposal.createPostRequest( - ohttpRelay: ohttp_relay, - ); - var fallback_response = await agent.post( - Uri.parse(request_response.request.url), - headers: {"Content-Type": request_response.request.contentType}, - body: request_response.request.body, - ); - proposal.processResponse( - body: fallback_response.bodyBytes, - ohttpContext: request_response.clientResponse, - ); - - // ********************** - // Inside the Sender: - // Sender checks, signs, finalizes, extracts, and broadcasts - // Replay post fallback to get the response - payjoin.PollingForProposalTransitionOutcome? poll_outcome; - var attempts = 0; - while (true) { - payjoin.RequestOhttpContext ohttp_context_request = send_ctx - .createPollRequest(ohttpRelay: ohttp_relay); - var final_response = await agent.post( - Uri.parse(ohttp_context_request.request.url), - headers: {"Content-Type": ohttp_context_request.request.contentType}, - body: ohttp_context_request.request.body, - ); - poll_outcome = send_ctx - .processResponse( - response: final_response.bodyBytes, - ohttpCtx: ohttp_context_request.ohttpCtx, - ) - .save(persister: sender_persister); - - if (poll_outcome - is payjoin.ProgressPollingForProposalTransitionOutcome) { - break; - } - - attempts += 1; - if (attempts >= 3) { - // Receiver not ready yet; mirror Python's tolerant polling. - return; - } - } - - final progressOutcome = - poll_outcome as payjoin.ProgressPollingForProposalTransitionOutcome; - var payjoin_psbt = jsonDecode( - sender.call( - method: "walletprocesspsbt", - params: [progressOutcome.psbtBase64], - ), - )["psbt"]; - var final_psbt = jsonDecode( - sender.call( - method: "finalizepsbt", - params: [payjoin_psbt, jsonEncode(false)], - ), - )["psbt"]; - var final_tx_hex = jsonDecode( - sender.call( - method: "finalizepsbt", - params: [final_psbt, jsonEncode(true)], - ), - )["hex"]; - sender.call( - method: "sendrawtransaction", - params: [jsonEncode(final_tx_hex)], - ); + test( + 'Test integration v2 to v2 (callback)', + () async => run_integration_v2_to_v2(TransitionMode.callback), + timeout: const Timeout(Duration(minutes: 5)), + ); - // Check resulting transaction and balances - var decodedTx = jsonDecode( - sender.call( - method: "decoderawtransaction", - params: [jsonEncode(final_tx_hex)], - ), - ); - var network_fees = - (jsonDecode( - sender.call( - method: "decodepsbt", - params: [jsonEncode(final_psbt)], - ), - )["fee"] - as num) - .toDouble(); - // Sender sent the entire value of their utxo to the receiver (minus fees) - expect(decodedTx["vin"].length, 2); - expect(decodedTx["vout"].length, 1); - expect( - jsonDecode( - receiver.call(method: "getbalances", params: []), - )["mine"]["untrusted_pending"], - 100 - network_fees, - ); - expect(jsonDecode(sender.call(method: "getbalance", params: [])), 0.0); - }, timeout: const Timeout(Duration(minutes: 5))); + test( + 'Test integration v2 to v2 (nonblocking)', + () async => run_integration_v2_to_v2(TransitionMode.nonblocking), + timeout: const Timeout(Duration(minutes: 5)), + ); }); } diff --git a/payjoin-ffi/javascript/test/integration.test.ts b/payjoin-ffi/javascript/test/integration.test.ts index 36aed0f2c..4ff072ddb 100644 --- a/payjoin-ffi/javascript/test/integration.test.ts +++ b/payjoin-ffi/javascript/test/integration.test.ts @@ -7,8 +7,8 @@ import { payjoin as nodejsPayjoin, uniffiInitAsync as nodejsUniffiInitAsync, } from "payjoin"; -import * as webPayjoinModule from "../src/web/generated/payjoin.js"; -import initWebAsync from "../src/web/generated/wasm-bindgen/index.js"; +import * as webPayjoinModule from "../dist/web/generated/payjoin.js"; +import initWebAsync from "../dist/web/generated/wasm-bindgen/index.js"; import { InMemoryReceiverPersister, InMemorySenderPersister } from "./utils.ts"; const __filename = fileURLToPath(import.meta.url); @@ -31,13 +31,17 @@ interface Utxo { scriptPubKey: string; } +type TransitionMode = "callback" | "nonblocking"; + type PayjoinModule = typeof nodejsPayjoin; const webPayjoin = webPayjoinModule as unknown as PayjoinModule; -// Helper types to avoid repeating InstanceType everywhere. -type PJ = InstanceType< - PayjoinModule[K] & (new (...args: any) => any) ->; +type PJ = PayjoinModule[K] extends { + prototype: infer P; +} + ? P + : never; + type PJNested< K extends keyof PayjoinModule, N extends keyof PayjoinModule[K], @@ -137,7 +141,7 @@ class CheckInputsNotSeenCallback { this.connection = connection; } - callback(_outpoint: ArrayBuffer): boolean { + callback(_outpoint: nodejsPayjoin.OutPoint): boolean { if (this.connection) { } return false; @@ -159,18 +163,6 @@ class ProcessPsbtCallback { } } -function createReceiverContext( - payjoin: PayjoinModule, - address: string, - directory: string, - ohttpKeys: ReturnType, - persister: InMemoryReceiverPersister, -): PJ<"Initialized"> { - return new payjoin.ReceiverBuilder(address, directory, ohttpKeys) - .build() - .save(persister); -} - function buildSweepPsbt( sender: testUtils.RpcClient, pjUri: PJ<"PjUri">, @@ -230,254 +222,253 @@ function getInputs( return inputs; } -async function processProvisionalProposal( - proposal: PJ<"ProvisionalProposal">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - return proposal - .finalizeProposal(new ProcessPsbtCallback(receiver)) - .save(recvPersister); -} +class ReceiverProcessor { + constructor( + private readonly payjoin: PayjoinModule, + private readonly receiver: testUtils.RpcClient, + private readonly recvPersister: InMemoryReceiverPersister, + private readonly mode: TransitionMode, + ) {} + + private async processProvisionalProposal( + proposal: PJ<"ProvisionalProposal">, + ): Promise> { + if (this.mode === "callback") { + return proposal + .finalizeProposal(new ProcessPsbtCallback(this.receiver)) + .save(this.recvPersister) as PJ<"PayjoinProposal">; + } -async function processWantsFeeRange( - proposal: PJ<"WantsFeeRange">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsFeeRange = proposal.applyFeeRange(1n, 10n).save(recvPersister); - return await processProvisionalProposal( - wantsFeeRange, - receiver, - recvPersister, - ); -} + const signedPsbt = new ProcessPsbtCallback(this.receiver).callback( + proposal.psbtToSign(), + ); + return proposal + .finalizeSignedProposal(signedPsbt) + .save(this.recvPersister) as PJ<"PayjoinProposal">; + } -async function processWantsInputs( - payjoin: PayjoinModule, - proposal: PJ<"WantsInputs">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const provisionalProposal = proposal - .contributeInputs(getInputs(payjoin, receiver)) - .commitInputs() - .save(recvPersister); - return await processWantsFeeRange( - provisionalProposal, - receiver, - recvPersister, - ); -} + private async processWantsFeeRange( + proposal: PJ<"WantsFeeRange">, + ): Promise> { + const provisionalProposal = proposal + .applyFeeRange(1n, 10n) + .save(this.recvPersister) as PJ<"ProvisionalProposal">; + return this.processProvisionalProposal(provisionalProposal); + } -async function processWantsOutputs( - payjoin: PayjoinModule, - proposal: PJ<"WantsOutputs">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsInputs = proposal.commitOutputs().save(recvPersister); - return await processWantsInputs( - payjoin, - wantsInputs, - receiver, - recvPersister, - ); -} + private async processWantsInputs( + proposal: PJ<"WantsInputs">, + ): Promise> { + const provisionalProposal = proposal + .contributeInputs(getInputs(this.payjoin, this.receiver)) + .commitInputs() + .save(this.recvPersister) as PJ<"WantsFeeRange">; + return this.processWantsFeeRange(provisionalProposal); + } -async function processOutputsUnknown( - payjoin: PayjoinModule, - proposal: PJ<"OutputsUnknown">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsOutputs = proposal - .identifyReceiverOutputs(new IsScriptOwnedCallback(receiver)) - .save(recvPersister); - return await processWantsOutputs( - payjoin, - wantsOutputs, - receiver, - recvPersister, - ); -} + private async processWantsOutputs( + proposal: PJ<"WantsOutputs">, + ): Promise> { + const wantsInputs = proposal + .commitOutputs() + .save(this.recvPersister) as PJ<"WantsInputs">; + return this.processWantsInputs(wantsInputs); + } -async function processMaybeInputsSeen( - payjoin: PayjoinModule, - proposal: PJ<"MaybeInputsSeen">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const outputsUnknown = proposal - .checkNoInputsSeenBefore(new CheckInputsNotSeenCallback(receiver)) - .save(recvPersister); - return await processOutputsUnknown( - payjoin, - outputsUnknown, - receiver, - recvPersister, - ); -} + private async processOutputsUnknown( + proposal: PJ<"OutputsUnknown">, + ): Promise> { + let wantsOutputs: PJ<"WantsOutputs">; + + if (this.mode === "callback") { + wantsOutputs = proposal + .identifyReceiverOutputs( + new IsScriptOwnedCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"WantsOutputs">; + } else { + const taggedRefs = proposal + .getOutputScriptRefs() + .map((ref) => + ref.mark( + new IsScriptOwnedCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + wantsOutputs = proposal + .applyOutputOwnedChecks(taggedRefs) + .save(this.recvPersister) as PJ<"WantsOutputs">; + } -async function processMaybeInputsOwned( - payjoin: PayjoinModule, - proposal: PJ<"MaybeInputsOwned">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const maybeInputsOwned = proposal - .checkInputsNotOwned(new IsScriptOwnedCallback(receiver)) - .save(recvPersister); - return await processMaybeInputsSeen( - payjoin, - maybeInputsOwned, - receiver, - recvPersister, - ); -} + return this.processWantsOutputs(wantsOutputs); + } -async function processUncheckedProposal( - payjoin: PayjoinModule, - proposal: PJ<"UncheckedOriginalPayload">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const uncheckedProposal = proposal - .checkBroadcastSuitability( - undefined, - new MempoolAcceptanceCallback(receiver), - ) - .save(recvPersister); - return await processMaybeInputsOwned( - payjoin, - uncheckedProposal, - receiver, - recvPersister, - ); -} + private async processMaybeInputsSeen( + proposal: PJ<"MaybeInputsSeen">, + ): Promise> { + let outputsUnknown: PJ<"OutputsUnknown">; + + if (this.mode === "callback") { + outputsUnknown = proposal + .checkNoInputsSeenBefore( + new CheckInputsNotSeenCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"OutputsUnknown">; + } else { + const taggedRefs = proposal + .getInputOutpointRefs() + .map((ref) => + ref.mark( + new CheckInputsNotSeenCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + outputsUnknown = proposal + .applyInputSeenChecks(taggedRefs) + .save(this.recvPersister) as PJ<"OutputsUnknown">; + } -async function retrieveReceiverProposal( - payjoin: PayjoinModule, - receiver: PJ<"Initialized">, - receiverRpc: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, - ohttpRelay: string, -): Promise | null> { - const request = receiver.createPollRequest(ohttpRelay); - const response = await fetch(request.request.url, { - method: "POST", - headers: { "Content-Type": request.request.contentType }, - body: request.request.body, - }); - const responseBuffer = await response.arrayBuffer(); - const res = receiver - .processResponse(responseBuffer, request.clientResponse) - .save(recvPersister); - - if (res instanceof payjoin.InitializedTransitionOutcome.Stasis) { - return null; - } else if (res instanceof payjoin.InitializedTransitionOutcome.Progress) { - const proposal = res.inner.inner; - return await processUncheckedProposal( - payjoin, - proposal, - receiverRpc, - recvPersister, - ); + return this.processOutputsUnknown(outputsUnknown); } - throw new Error(`Unknown initialized transition outcome`); -} + private async processMaybeInputsOwned( + proposal: nodejsPayjoin.MaybeInputsOwned, + ): Promise> { + let maybeInputsSeen: PJ<"MaybeInputsSeen">; + + if (this.mode === "callback") { + maybeInputsSeen = proposal + .checkInputsNotOwned(new IsScriptOwnedCallback(this.receiver)) + .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + } else { + const taggedRefs = proposal + .getInputScriptRefs() + .map((ref) => + ref.mark( + new IsScriptOwnedCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + maybeInputsSeen = proposal + .applyInputOwnedChecks(taggedRefs) + .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + } -async function processReceiverProposal( - payjoin: PayjoinModule, - receiver: - | PJ<"Initialized"> - | PJ<"UncheckedOriginalPayload"> - | PJ<"MaybeInputsOwned"> - | PJ<"MaybeInputsSeen"> - | PJ<"OutputsUnknown"> - | PJ<"WantsOutputs"> - | PJ<"WantsInputs"> - | PJ<"WantsFeeRange"> - | PJ<"ProvisionalProposal"> - | PJ<"PayjoinProposal">, - receiverRpc: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, - ohttpRelay: string, -): Promise | null> { - if (receiver instanceof payjoin.Initialized) { - return await retrieveReceiverProposal( - payjoin, - receiver, - receiverRpc, - recvPersister, - ohttpRelay, - ); - } - if (receiver instanceof payjoin.UncheckedOriginalPayload) { - return await processUncheckedProposal( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); + return this.processMaybeInputsSeen(maybeInputsSeen); } - if (receiver instanceof payjoin.MaybeInputsOwned) { - return await processMaybeInputsOwned( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.MaybeInputsSeen) { - return await processMaybeInputsSeen( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.OutputsUnknown) { - return await processOutputsUnknown( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsOutputs) { - return await processWantsOutputs( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsInputs) { - return await processWantsInputs( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsFeeRange) { - return await processWantsFeeRange(receiver, receiverRpc, recvPersister); + + private async processUncheckedProposal( + proposal: PJ<"UncheckedOriginalPayload">, + ): Promise> { + let maybeInputsOwned: PJ<"MaybeInputsOwned">; + + if (this.mode === "callback") { + maybeInputsOwned = proposal + .checkBroadcastSuitability( + undefined, + new MempoolAcceptanceCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + } else { + const canBroadcastResult = new MempoolAcceptanceCallback( + this.receiver, + ).callback(proposal.extractTxToCheckBroadcastSuitability()); + maybeInputsOwned = proposal + .applyBroadcastSuitability(undefined, canBroadcastResult) + .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + } + + return this.processMaybeInputsOwned(maybeInputsOwned); } - if (receiver instanceof payjoin.ProvisionalProposal) { - return await processProvisionalProposal( - receiver, - receiverRpc, - recvPersister, - ); + + createReceiverContext( + address: string, + directory: string, + ohttpKeys: ReturnType, + ): PJ<"Initialized"> { + return new this.payjoin.ReceiverBuilder(address, directory, ohttpKeys) + .build() + .save(this.recvPersister) as PJ<"Initialized">; } - if (receiver instanceof payjoin.PayjoinProposal) { - return receiver; + + private async retrieveReceiverProposal( + session: PJ<"Initialized">, + ohttpRelay: string, + ): Promise | null> { + const request = session.createPollRequest(ohttpRelay); + const response = await fetch(request.request.url, { + method: "POST", + headers: { "Content-Type": request.request.contentType }, + body: request.request.body, + }); + const responseBuffer = await response.arrayBuffer(); + const res = session + .processResponse(responseBuffer, request.clientResponse) + .save(this.recvPersister); + + if (res instanceof this.payjoin.InitializedTransitionOutcome.Stasis) { + return null; + } else if ( + res instanceof this.payjoin.InitializedTransitionOutcome.Progress + ) { + return this.processUncheckedProposal( + res.inner.inner as PJ<"UncheckedOriginalPayload">, + ); + } + + throw new Error(`Unknown initialized transition outcome`); } - throw new Error(`Unknown receiver state`); + async processReceiverProposal( + receiver: + | PJ<"Initialized"> + | PJ<"UncheckedOriginalPayload"> + | PJ<"MaybeInputsOwned"> + | PJ<"MaybeInputsSeen"> + | PJ<"OutputsUnknown"> + | PJ<"WantsOutputs"> + | PJ<"WantsInputs"> + | PJ<"WantsFeeRange"> + | PJ<"ProvisionalProposal"> + | PJ<"PayjoinProposal">, + ohttpRelay: string, + ): Promise | null> { + if (receiver instanceof this.payjoin.Initialized) { + return this.retrieveReceiverProposal(receiver, ohttpRelay); + } + if (receiver instanceof this.payjoin.UncheckedOriginalPayload) { + return this.processUncheckedProposal(receiver); + } + if (receiver instanceof this.payjoin.MaybeInputsOwned) { + return this.processMaybeInputsOwned(receiver); + } + if (receiver instanceof this.payjoin.MaybeInputsSeen) { + return this.processMaybeInputsSeen(receiver); + } + if (receiver instanceof this.payjoin.OutputsUnknown) { + return this.processOutputsUnknown(receiver); + } + if (receiver instanceof this.payjoin.WantsOutputs) { + return this.processWantsOutputs(receiver); + } + if (receiver instanceof this.payjoin.WantsInputs) { + return this.processWantsInputs(receiver); + } + if (receiver instanceof this.payjoin.WantsFeeRange) { + return this.processWantsFeeRange(receiver); + } + if (receiver instanceof this.payjoin.ProvisionalProposal) { + return this.processProvisionalProposal(receiver); + } + if (receiver instanceof this.payjoin.PayjoinProposal) { + return receiver; + } + + throw new Error(`Unknown receiver state`); + } } function testFfiValidation(payjoin: PayjoinModule): void { @@ -580,7 +571,10 @@ function testFfiValidation(payjoin: PayjoinModule): void { }, /AmountOutOfRange/); } -async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { +async function testIntegrationV2ToV2( + payjoin: PayjoinModule, + mode: TransitionMode, +): Promise { const env = testUtils.initBitcoindSenderReceiver(); const receiver = env.getReceiver(); const sender = env.getSender(); @@ -598,21 +592,22 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { ); const recvPersister = new InMemoryReceiverPersister(); + const recvProcessor = new ReceiverProcessor( + payjoin, + receiver, + recvPersister, + mode, + ); const senderPersister = new InMemorySenderPersister(); - const session = createReceiverContext( - payjoin, + const session = recvProcessor.createReceiverContext( receiverAddress, directory, ohttpKeys, - recvPersister, ); - let processResponse = await processReceiverProposal( - payjoin, + const processResponse = await recvProcessor.processReceiverProposal( session, - receiver, - recvPersister, ohttpRelay, ); assert.strictEqual( @@ -622,7 +617,7 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { ); const pjUri = session.pjUri(); - const psbt = buildSweepPsbt(sender, pjUri); + const psbt = buildSweepPsbt(sender, pjUri as PJ<"PjUri">); const reqCtx = new payjoin.SenderBuilder(psbt, pjUri) .buildRecommended(1000n) .save(senderPersister); @@ -638,11 +633,8 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { .processResponse(responseBuffer, request.ohttpCtx) .save(senderPersister); - let payjoinProposal = await processReceiverProposal( - payjoin, + const payjoinProposal = await recvProcessor.processReceiverProposal( session, - receiver, - recvPersister, ohttpRelay, ); assert.notStrictEqual( @@ -736,11 +728,13 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { async function runTests(): Promise { await nodejsUniffiInitAsync(); testFfiValidation(nodejsPayjoin); - await testIntegrationV2ToV2(nodejsPayjoin); + await testIntegrationV2ToV2(nodejsPayjoin, "callback"); + await testIntegrationV2ToV2(nodejsPayjoin, "nonblocking"); await webUniffiInitAsync(); testFfiValidation(webPayjoin); - await testIntegrationV2ToV2(webPayjoin); + await testIntegrationV2ToV2(webPayjoin, "callback"); + await testIntegrationV2ToV2(webPayjoin, "nonblocking"); } runTests().catch((error: unknown) => { diff --git a/payjoin-ffi/python/test/test_payjoin_integration_test.py b/payjoin-ffi/python/test/test_payjoin_integration_test.py index 146fc41fe..ef518015a 100644 --- a/payjoin-ffi/python/test/test_payjoin_integration_test.py +++ b/payjoin-ffi/python/test/test_payjoin_integration_test.py @@ -2,7 +2,7 @@ import sys import httpx import json -from typing import cast, Protocol, Any +from typing import cast, Protocol, Any, Literal from payjoin import * from payjoin.http import fetch_ohttp_keys @@ -22,6 +22,9 @@ class HasInner(Protocol): inner: Any +TransitionMode = Literal["callback", "nonblocking"] + + class TestPayjoin(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): @@ -92,12 +95,14 @@ async def process_receiver_proposal( receiver: ReceiveSession, recv_persister: InMemoryReceiverPersister, ohttp_relay: str, + mode: TransitionMode, ) -> Optional[ReceiveSession.PAYJOIN_PROPOSAL]: if receiver.is_INITIALIZED(): res = await self.retrieve_receiver_proposal( cast(ReceiveSession.INITIALIZED, receiver).inner, recv_persister, ohttp_relay, + mode, ) if res is None: return None @@ -107,35 +112,49 @@ async def process_receiver_proposal( return await self.process_unchecked_proposal( cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, receiver).inner, recv_persister, + mode, ) if receiver.is_MAYBE_INPUTS_OWNED(): return await self.process_maybe_inputs_owned( - cast(ReceiveSession.MAYBE_INPUTS_OWNED, receiver).inner, recv_persister + cast(ReceiveSession.MAYBE_INPUTS_OWNED, receiver).inner, + recv_persister, + mode, ) if receiver.is_MAYBE_INPUTS_SEEN(): return await self.process_maybe_inputs_seen( - cast(ReceiveSession.MAYBE_INPUTS_SEEN, receiver).inner, recv_persister + cast(ReceiveSession.MAYBE_INPUTS_SEEN, receiver).inner, + recv_persister, + mode, ) if receiver.is_OUTPUTS_UNKNOWN(): return await self.process_outputs_unknown( - cast(ReceiveSession.OUTPUTS_UNKNOWN, receiver).inner, recv_persister + cast(ReceiveSession.OUTPUTS_UNKNOWN, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_OUTPUTS(): return await self.process_wants_outputs( - cast(ReceiveSession.WANTS_OUTPUTS, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_OUTPUTS, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_INPUTS(): return await self.process_wants_inputs( - cast(ReceiveSession.WANTS_INPUTS, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_INPUTS, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_FEE_RANGE(): return await self.process_wants_fee_range( - cast(ReceiveSession.WANTS_FEE_RANGE, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_FEE_RANGE, receiver).inner, + recv_persister, + mode, ) if receiver.is_PROVISIONAL_PROPOSAL(): return await self.process_provisional_proposal( cast(ReceiveSession.PROVISIONAL_PROPOSAL, receiver).inner, recv_persister, + mode, ) if receiver.is_PAYJOIN_PROPOSAL(): return cast(ReceiveSession.PAYJOIN_PROPOSAL, receiver) @@ -161,6 +180,7 @@ async def retrieve_receiver_proposal( receiver: Initialized, recv_persister: InMemoryReceiverPersister, ohttp_relay: str, + mode: TransitionMode, ): agent = httpx.AsyncClient() request: RequestResponse = receiver.create_poll_request(ohttp_relay) @@ -175,80 +195,155 @@ async def retrieve_receiver_proposal( if res.is_STASIS(): return None return await self.process_unchecked_proposal( - cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, res).inner, recv_persister + cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, res).inner, + recv_persister, + mode, ) async def process_unchecked_proposal( self, proposal: UncheckedOriginalPayload, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - receiver = proposal.check_broadcast_suitability( - None, MempoolAcceptanceCallback(self.receiver) - ).save(recv_persister) - return await self.process_maybe_inputs_owned(receiver, recv_persister) + if mode == "callback": + receiver = proposal.check_broadcast_suitability( + None, MempoolAcceptanceCallback(self.receiver) + ).save(recv_persister) + else: + can_broadcast = MempoolAcceptanceCallback(self.receiver).callback( + proposal.extract_tx_to_check_broadcast_suitability() + ) + receiver = proposal.apply_broadcast_suitability(None, can_broadcast).save( + recv_persister + ) + return await self.process_maybe_inputs_owned(receiver, recv_persister, mode) async def process_maybe_inputs_owned( self, proposal: MaybeInputsOwned, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - maybe_inputs_owned = proposal.check_inputs_not_owned( - IsScriptOwnedCallback(self.receiver) - ).save(recv_persister) - return await self.process_maybe_inputs_seen(maybe_inputs_owned, recv_persister) + if mode == "callback": + maybe_inputs_owned = proposal.check_inputs_not_owned( + IsScriptOwnedCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark(IsScriptOwnedCallback(self.receiver).callback(ref.get_value())) + for ref in proposal.get_input_script_refs() + ] + maybe_inputs_owned = proposal.apply_input_owned_checks(tagged_refs).save( + recv_persister + ) + return await self.process_maybe_inputs_seen( + maybe_inputs_owned, recv_persister, mode + ) async def process_maybe_inputs_seen( - self, proposal: MaybeInputsSeen, recv_persister: InMemoryReceiverPersister + self, + proposal: MaybeInputsSeen, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - outputs_unknown = proposal.check_no_inputs_seen_before( - CheckInputsNotSeenCallback(self.receiver) - ).save(recv_persister) - return await self.process_outputs_unknown(outputs_unknown, recv_persister) + if mode == "callback": + outputs_unknown = proposal.check_no_inputs_seen_before( + CheckInputsNotSeenCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark( + CheckInputsNotSeenCallback(self.receiver).callback(ref.get_value()) + ) + for ref in proposal.get_input_outpoint_refs() + ] + outputs_unknown = proposal.apply_input_seen_checks(tagged_refs).save( + recv_persister + ) + return await self.process_outputs_unknown(outputs_unknown, recv_persister, mode) async def process_outputs_unknown( - self, proposal: OutputsUnknown, recv_persister: InMemoryReceiverPersister + self, + proposal: OutputsUnknown, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - wants_outputs = proposal.identify_receiver_outputs( - IsScriptOwnedCallback(self.receiver) - ).save(recv_persister) - return await self.process_wants_outputs(wants_outputs, recv_persister) + if mode == "callback": + wants_outputs = proposal.identify_receiver_outputs( + IsScriptOwnedCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark(IsScriptOwnedCallback(self.receiver).callback(ref.get_value())) + for ref in proposal.get_output_script_refs() + ] + wants_outputs = proposal.apply_output_owned_checks(tagged_refs).save( + recv_persister + ) + return await self.process_wants_outputs(wants_outputs, recv_persister, mode) async def process_wants_outputs( - self, proposal: WantsOutputs, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsOutputs, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): wants_inputs = proposal.commit_outputs().save(recv_persister) - return await self.process_wants_inputs(wants_inputs, recv_persister) + return await self.process_wants_inputs(wants_inputs, recv_persister, mode) async def process_wants_inputs( - self, proposal: WantsInputs, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsInputs, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): provisional_proposal = ( proposal.contribute_inputs(get_inputs(self.receiver)) .commit_inputs() .save(recv_persister) ) - return await self.process_wants_fee_range(provisional_proposal, recv_persister) + return await self.process_wants_fee_range( + provisional_proposal, recv_persister, mode + ) async def process_wants_fee_range( - self, proposal: WantsFeeRange, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsFeeRange, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): provisional_proposal = proposal.apply_fee_range(1, 10).save(recv_persister) return await self.process_provisional_proposal( - provisional_proposal, recv_persister + provisional_proposal, recv_persister, mode ) async def process_provisional_proposal( self, proposal: ProvisionalProposal, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - payjoin_proposal = proposal.finalize_proposal( - ProcessPsbtCallback(self.receiver) - ).save(recv_persister) + if mode == "callback": + payjoin_proposal = proposal.finalize_proposal( + ProcessPsbtCallback(self.receiver) + ).save(recv_persister) + else: + signed_psbt = ProcessPsbtCallback(self.receiver).callback( + proposal.psbt_to_sign() + ) + payjoin_proposal = proposal.finalize_signed_proposal(signed_psbt).save( + recv_persister + ) return ReceiveSession.PAYJOIN_PROPOSAL(payjoin_proposal) - async def test_integration_v2_to_v2(self): + def setUp(self): + sender_address = json.loads(self.sender.call("getnewaddress", [])) + self.sender.call( + "generatetoaddress", [json.dumps(101), json.dumps(sender_address)] + ) + + async def _run_integration_v2_to_v2(self, mode: TransitionMode): try: receiver_address = json.loads(self.receiver.call("getnewaddress", [])) init_tracing() @@ -271,6 +366,7 @@ async def test_integration_v2_to_v2(self): cast(ReceiveSession, ReceiveSession.INITIALIZED(session)), recv_persister, ohttp_relay, + mode, ) self.assertIsNone(process_response) @@ -305,6 +401,7 @@ async def test_integration_v2_to_v2(self): cast(ReceiveSession, ReceiveSession.INITIALIZED(session)), recv_persister, ohttp_relay, + mode, ) self.assertIsNotNone(payjoin_proposal) self.assertEqual( @@ -385,6 +482,12 @@ async def test_integration_v2_to_v2(self): print("Caught:", e) raise + async def test_integration_v2_to_v2_callback(self): + await self._run_integration_v2_to_v2("callback") + + async def test_integration_v2_to_v2_nonblocking(self): + await self._run_integration_v2_to_v2("nonblocking") + def build_sweep_psbt(sender: RpcClient, pj_uri: PjUri) -> str: outputs = {} diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 47d9833f0..b2af0b1af 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -344,9 +344,6 @@ impl InitialReceiveTransition { } } -#[derive(Clone, Debug, uniffi::Object)] -pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); - /// Primitive representation of a transaction output for the FFI boundary. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] pub struct TxOut { @@ -394,7 +391,7 @@ impl TxIn { } /// Primitive representation of an outpoint for the FFI boundary. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] +#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, uniffi::Record)] pub struct OutPoint { /// Hex-encoded txid (big-endian). pub txid: String, @@ -458,6 +455,9 @@ impl From for Weight { fn from(value: payjoin::bitcoin::Weight) -> Self { Weight { weight_units: value.to_wu() } } } +#[derive(Clone, Debug, uniffi::Object)] +pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); + #[uniffi::export] impl ReceiverBuilder { /// Creates a new [`Initialized`] with the provided parameters. @@ -632,7 +632,11 @@ impl Initialized { .map_err(Into::into) } - /// The response can either be an UncheckedOriginalPayload or an ACCEPTED message indicating no UncheckedOriginalPayload is available yet. + /// Process the polling response from the directory. + /// + /// Returns an [`InitializedTransition`] that, once persisted, yields either + /// an [`UncheckedOriginalPayload`] if the sender's Original PSBT is available, + /// or [`Initialized`] if no proposal has arrived yet. pub fn process_response(&self, body: &[u8], ctx: &ClientResponse) -> InitializedTransition { InitializedTransition(Arc::new(RwLock::new(Some( self.0.clone().process_response(body, ctx.into()), @@ -713,6 +717,11 @@ pub trait CanBroadcast: Send + Sync { #[uniffi::export] impl UncheckedOriginalPayload { + /// Check that the sender's Original PSBT is suitable for broadcast, ensuring + /// it can be used as a fallback if the payjoin does not complete. + /// + /// Returns an [`UncheckedOriginalPayloadTransition`] that, once persisted, + /// yields a [`MaybeInputsOwned`] to continue validation. pub fn check_broadcast_suitability( &self, min_fee_rate_sat_per_kwu: Option, @@ -728,6 +737,32 @@ impl UncheckedOriginalPayload { ))))) } + /// Extract the transaction from the Original PSBT for external broadcast suitability checks. + /// + /// Returns the consensus-encoded raw transaction bytes. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> Vec { + payjoin::bitcoin::consensus::encode::serialize( + &self.0.clone().extract_tx_to_check_broadcast_suitability(), + ) + } + + /// Apply the result of an external broadcast suitability check, ensuring + /// the Original PSBT can be used as a fallback if the payjoin does + /// not complete. + /// + /// Returns an [`UncheckedOriginalPayloadTransition`] that, once persisted, + /// yields a [`MaybeInputsOwned`] to continue validation. + pub fn apply_broadcast_suitability( + &self, + min_fee_rate_sat_per_kwu: Option, + can_broadcast: bool, + ) -> Result { + let min_fee_rate = validate_fee_rate_sat_per_kwu_opt(min_fee_rate_sat_per_kwu)?; + Ok(UncheckedOriginalPayloadTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_broadcast_suitability(min_fee_rate, can_broadcast), + ))))) + } + /// Call this method if the only way to initiate a Payjoin with this receiver /// requires manual intervention, as in most consumer wallets. /// @@ -740,6 +775,65 @@ impl UncheckedOriginalPayload { } } +trait FfiTaggedReference { + fn get_result(&self) -> bool; + fn get_value(&self) -> V; +} + +fn to_tagged_references( + references: impl Iterator>, + checks: Vec>, +) -> Result< + impl Iterator>, + payjoin::error::ImplementationError, +> +where + V: Clone, + T: payjoin::receive::Tag, + R: FfiTaggedReference, + Vffi: From + PartialEq, +{ + payjoin::receive::check_references(references, &mut move |pj_ref: &V| { + let found_result = checks.iter().find_map(|ffi_ref| { + if Vffi::from(pj_ref.clone()) == ffi_ref.get_value() { + Some(ffi_ref.get_result()) + } else { + None + } + }); + match found_result { + Some(result) => Ok(result), + None => Err(payjoin::ImplementationError::from( + "Returned input owned checks is missing result for script {script}", + )), + } + }) +} + +#[derive(Debug, uniffi::Object)] +pub struct InputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputOwnedTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct InputOwnedTaggedReference { + value: Vec, + result: bool, +} + +impl FfiTaggedReference> for InputOwnedTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> Vec { self.value.clone() } +} + #[derive(Clone, uniffi::Object)] pub struct MaybeInputsOwned(payjoin::receive::v2::Receiver); @@ -777,12 +871,21 @@ pub trait IsScriptOwned: Send + Sync { #[uniffi::export] impl MaybeInputsOwned { - ///The Sender’s Original PSBT + /// Extract the transaction from the Original PSBT for scheduling broadcast as a + /// fallback in case the payjoin does not complete. + /// + /// Returns the consensus-encoded raw transaction bytes. pub fn extract_tx_to_schedule_broadcast(&self) -> Vec { payjoin::bitcoin::consensus::encode::serialize( &self.0.clone().extract_tx_to_schedule_broadcast(), ) } + + /// Check that none of the Original PSBT's inputs belong to the receiver, + /// preventing an attacker from spending the receiver's own inputs. + /// + /// Returns a [`MaybeInputsOwnedTransition`] that, once persisted, + /// yields a [`MaybeInputsSeen`] to continue validation. pub fn check_inputs_not_owned( &self, is_owned: Arc, @@ -793,6 +896,63 @@ impl MaybeInputsOwned { }), )))) } + + /// Get references to the input scripts for external ownership checks. + /// + /// Each reference can be marked with the result via [`InputOwnedReference::mark`] + /// and passed to [`MaybeInputsOwned::apply_input_owned_checks`]. + pub fn get_input_script_refs(&self) -> Result>, ReceiverError> { + self.0 + .clone() + .get_input_script_refs() + .map(|iter| { + iter.map(|input_script_ref| Arc::new(InputOwnedReference(input_script_ref))) + .collect::>() + }) + .map_err(ReceiverError::from) + } + + /// Apply the results of external input ownership checks, ensuring none of the + /// inputs are owned by the receiver. This prevents an attacker from spending + /// the receiver's own inputs. + /// + /// Returns a [`MaybeInputsOwnedTransition`] that, once persisted, + /// yields a [`MaybeInputsSeen`] to continue validation. + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: Vec>, + ) -> Result { + let references = self.0.clone().get_input_script_refs()?; + let checked = to_tagged_references(references, checked_input_scripts) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(MaybeInputsOwnedTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_owned_checks(checked), + ))))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct InputSeenReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputSeenReference { + pub fn get_value(&self) -> OutPoint { self.0.get_value().into() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputSeenTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct InputSeenTaggedReference { + value: OutPoint, + result: bool, +} + +impl FfiTaggedReference for InputSeenTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> OutPoint { self.value.clone() } } #[derive(Clone, uniffi::Object)] @@ -832,6 +992,12 @@ pub trait IsOutputKnown: Send + Sync { #[uniffi::export] impl MaybeInputsSeen { + /// Check that none of the inputs have been seen before, preventing input + /// probing and replay attacks (where inputs have been used in a previous + /// payjoin attempt). + /// + /// Returns a [`MaybeInputsSeenTransition`] that, once persisted, + /// yields an [`OutputsUnknown`] to continue validation. pub fn check_no_inputs_seen_before( &self, is_known: Arc, @@ -844,6 +1010,60 @@ impl MaybeInputsSeen { }), )))) } + + /// Get references to the input outpoints for external outpoint seen checks. + /// + /// Each reference can be marked with the result via [`InputSeenReference::mark`] + /// and passed to [`MaybeInputsSeen::apply_input_seen_checks`]. + pub fn get_input_outpoint_refs(&self) -> Vec> { + self.0 + .clone() + .get_input_outpoint_refs() + .map(|input_outpoint_ref| Arc::new(InputSeenReference(input_outpoint_ref))) + .collect::>() + } + + /// Apply the results of external outpoint seen checks, ensuring none of + /// the inputs have been seen before. This prevents input probing and replay + /// attacks (where inputs have been used in a previous payjoin attempt). + /// + /// Returns a [`MaybeInputsSeenTransition`] that, once persisted, + /// yields an [`OutputsUnknown`] to continue validation. + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: Vec>, + ) -> Result { + let references = self.0.clone().get_input_outpoint_refs(); + let checked = to_tagged_references(references, checked_input_outpoints) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(MaybeInputsSeenTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_seen_checks(checked), + ))))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct OutputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl OutputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(OutputOwnedTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct OutputOwnedTaggedReference { + value: Vec, + result: bool, +} + +impl FfiTaggedReference> for OutputOwnedTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> Vec { self.value.clone() } } /// The receiver has not yet identified which outputs belong to the receiver. @@ -880,7 +1100,11 @@ impl_save_for_transition!(OutputsUnknownTransition, WantsOutputs); #[uniffi::export] impl OutputsUnknown { - /// Find which outputs belong to the receiver + /// Identify which outputs in the original transaction belong to the receiver + /// and ensure at least one output pays the receiver. + /// + /// Returns an [`OutputsUnknownTransition`] that, once persisted, + /// yields a [`WantsOutputs`] to continue the proposal. pub fn identify_receiver_outputs( &self, is_receiver_output: Arc, @@ -893,6 +1117,36 @@ impl OutputsUnknown { }), )))) } + + /// Get references to the output scripts for external ownership checks. + /// + /// Each reference can be marked with the result via [`OutputOwnedReference::mark`] + /// and passed to [`OutputsUnknown::apply_output_owned_checks`]. + pub fn get_output_script_refs(&self) -> Vec> { + self.0 + .clone() + .get_output_script_refs() + .map(|output_script_ref| Arc::new(OutputOwnedReference(output_script_ref))) + .collect::>() + } + + /// Apply the results of external output ownership checks, identifying which + /// outputs in the original transaction belong to the receiver and ensuring + /// at least one output pays the receiver. + /// + /// Returns an [`OutputsUnknownTransition`] that, once persisted, + /// yields a [`WantsOutputs`] to continue the proposal. + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: Vec>, + ) -> Result { + let references = self.0.clone().get_output_script_refs(); + let checked = to_tagged_references(references, checked_output_scripts) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(OutputsUnknownTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_output_owned_checks(checked), + ))))) + } } #[derive(uniffi::Object)] @@ -923,8 +1177,14 @@ impl_save_for_transition!(WantsOutputsTransition, WantsInputs); #[uniffi::export] impl WantsOutputs { + /// Returns whether output substitution is enabled for this session. pub fn output_substitution(&self) -> OutputSubstitution { self.0.output_substitution() } + /// Replace all receiver outputs with the provided `replacement_outputs`, + /// and set up the `drain_script` as the receiver-owned output whose value + /// may be adjusted based on modifications in subsequent states. + /// + /// Returns an updated [`WantsOutputs`] with the replaced outputs. pub fn replace_receiver_outputs( &self, replacement_outputs: Vec, @@ -942,6 +1202,9 @@ impl WantsOutputs { .map_err(Into::into) } + /// Substitute the receiver output script with the provided script. + /// + /// Returns an updated [`WantsOutputs`] with the substituted output. pub fn substitute_receiver_script( &self, output_script_pubkey: Vec, @@ -955,6 +1218,10 @@ impl WantsOutputs { .map_err(Into::into) } + /// Commit the output modifications and proceed to input contribution. + /// + /// Returns a [`WantsOutputsTransition`] that, once persisted, + /// yields a [`WantsInputs`]. pub fn commit_outputs(&self) -> WantsOutputsTransition { WantsOutputsTransition(Arc::new(RwLock::new(Some(self.0.clone().commit_outputs())))) } @@ -1011,6 +1278,9 @@ impl WantsInputs { } } + /// Add the provided inputs to the payjoin proposal. + /// + /// Returns an updated [`WantsInputs`] with the contributed inputs. pub fn contribute_inputs( &self, replacement_inputs: Vec>, @@ -1024,6 +1294,10 @@ impl WantsInputs { .map_err(Into::into) } + /// Commit the input contributions and proceed to fee negotiation. + /// + /// Returns a [`WantsInputsTransition`] that, once persisted, + /// yields a [`WantsFeeRange`]. pub fn commit_inputs(&self) -> WantsInputsTransition { WantsInputsTransition(Arc::new(RwLock::new(Some(self.0.clone().commit_inputs())))) } @@ -1162,6 +1436,10 @@ pub trait ProcessPsbt: Send + Sync { #[uniffi::export] impl ProvisionalProposal { + /// Finalize the proposal by signing the PSBT via the `process_psbt` callback. + /// + /// Returns a [`ProvisionalProposalTransition`] that, once persisted, + /// yields the final [`PayjoinProposal`]. pub fn finalize_proposal( &self, process_psbt: Arc, @@ -1176,7 +1454,20 @@ impl ProvisionalProposal { )))) } + /// Extract the PSBT that needs to be signed by the receiver's wallet. pub fn psbt_to_sign(&self) -> String { self.0.clone().psbt_to_sign().to_string() } + + /// Finalize the proposal with a signed PSBT. + /// + /// Returns a [`ProvisionalProposalTransition`] that, once persisted, + /// yields the final [`PayjoinProposal`]. + pub fn finalize_signed_proposal(&self, signed_psbt: String) -> ProvisionalProposalTransition { + ProvisionalProposalTransition(Arc::new(RwLock::new(Some( + self.0.clone().finalize_proposal(|_| { + Ok(Psbt::from_str(&signed_psbt).map_err(ImplementationError::new)?) + }), + )))) + } } #[derive(Clone, uniffi::Object)] @@ -1218,6 +1509,8 @@ impl_save_for_transition!(PayjoinProposalTransition, Monitor); #[uniffi::export] impl PayjoinProposal { + /// Returns the outpoints of receiver UTXOs that should be locked to + /// prevent double-spending while the payjoin is in progress. pub fn utxos_to_be_locked(&self) -> Vec { let mut outpoints: Vec = Vec::new(); for o in String { , @@ -1333,6 +1627,8 @@ impl HasReplyableErrorTransition { #[uniffi::export] impl HasReplyableError { + /// Construct an OHTTP encapsulated POST request to post the receiver + /// error response to the directory so it can be retrieved by the sender. pub fn create_error_request( &self, ohttp_relay: String, @@ -1342,6 +1638,11 @@ impl HasReplyableError { }) } + /// Process the response from the directory after posting the receiver + /// error response. + /// + /// Returns a [`HasReplyableErrorTransition`] that, once persisted, + /// completes the error reporting. pub fn process_error_response( &self, body: &[u8], @@ -1422,6 +1723,11 @@ fn try_deserialize_tx( #[uniffi::export] impl Monitor { + /// Check the network for the payjoin or fallback transaction via the + /// `transaction_exists` callback. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session if a transaction is found. pub fn check_payment( &self, transaction_exists: Arc, @@ -1433,6 +1739,50 @@ impl Monitor { .map_err(|e| ImplementationError::new(e).into()) }))))) } + + /// Returns the txid of the fallback transaction. + pub fn extract_fallback_txid(&self) -> String { + self.0.clone().extract_fallback_txid().to_string() + } + + /// Returns the txid of the payjoin proposal transaction. + pub fn extract_payjoin_proposal_txid(&self) -> String { + self.0.clone().extract_payjoin_proposal_txid().to_string() + } + + /// Check whether the fallback transaction can be monitored. If the + /// fallback transaction includes non-SegWit inputs, the fallback + /// transaction ID can change when the sender signs again, making + /// monitoring impossible and concluding the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, yields a + /// [`Monitor`] to continue monitoring or completes the session if + /// monitoring is not possible. + pub fn check_fallback_monitorable(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().check_fallback_monitorable())))) + } + + /// Signal that the fallback transaction exists on the network, + /// completing the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session. + pub fn fallback_tx_exists(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().fallback_tx_exists())))) + } + + /// Signal that the payjoin transaction exists on the network, + /// completing the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session. + pub fn payjoin_tx_exists( + &self, + payjoin_tx: Vec, + ) -> Result { + let tx = try_deserialize_tx(payjoin_tx)?; + Ok(MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().payjoin_tx_exists(tx)))))) + } } /// Session persister that should save and load events as JSON strings. diff --git a/payjoin/src/core/receive/common/mod.rs b/payjoin/src/core/receive/common/mod.rs index 9a93fb8f0..635c96159 100644 --- a/payjoin/src/core/receive/common/mod.rs +++ b/payjoin/src/core/receive/common/mod.rs @@ -863,7 +863,7 @@ mod tests { .calculate_psbt_context_with_fee_range(None, None) .expect("Contributed inputs should allow for valid fee contributions"); let payjoin_proposal = - psbt_context.finalize_proposal(|_| Ok(processed_psbt.clone())).expect("Valid psbt"); + psbt_context.finalize_signed_proposal(processed_psbt.clone()).expect("Valid psbt"); assert!(payjoin_proposal.xpub.is_empty()); diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index d7bd23e39..bed31cec1 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -3,6 +3,7 @@ use std::{error, fmt}; use crate::error_codes::ErrorCode::{ self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported, }; +use crate::ImplementationError; /// The top-level error type for the payjoin receiver #[derive(Debug)] @@ -29,6 +30,10 @@ impl From for Error { fn from(e: ProtocolError) -> Self { Error::Protocol(e) } } +impl From for Error { + fn from(e: ImplementationError) -> Self { Error::Implementation(e) } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index 5b28abcf5..c7a9e9b90 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -10,6 +10,7 @@ //! version 1, refer to the `receive::v1` module documentation after enabling the `v1` feature. use std::collections::BTreeMap; +use std::marker::PhantomData; use std::str::FromStr; use bitcoin::transaction::InputWeightPrediction; @@ -228,6 +229,142 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> { fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } } } +/// Holds a value that requires some form of boolean check. +#[derive(Debug)] +pub struct Reference { + value: V, + index: usize, + /// The final index in the set of [`Reference`]s to be checked + final_index: usize, + _tag: PhantomData, +} + +impl Reference +where + V: Clone, + T: Tag, +{ + fn new(value: V, index: usize, final_index: usize) -> Self { + Reference { value, index, final_index, _tag: PhantomData } + } + + /// Returns a [`TaggedReference`] that has been marked with the result of the boolean + /// check. + pub fn mark(&self, result: bool) -> TaggedReference { + TaggedReference { + value: self.value.clone(), + index: self.index, + final_index: self.final_index, + tag: T::new(result), + } + } + + /// Extracts the value to to be checked + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Holds the result of a checked [`Reference`]. Can only be created with [`Reference::mark`]. +#[derive(Debug)] +pub struct TaggedReference { + value: V, + index: usize, + final_index: usize, + tag: T, +} + +impl TaggedReference +where + V: Clone, + T: Tag, +{ + pub fn get_result(&self) -> bool { self.tag.result() } + pub fn get_index(&self) -> usize { self.index } + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Trait used to distinguish different types of validation +pub trait Tag { + fn new(result: bool) -> Self; + fn result(&self) -> bool; +} + +#[derive(Debug)] +pub struct InputOwnedTag { + is_owned: bool, +} + +impl Tag for InputOwnedTag { + fn new(result: bool) -> InputOwnedTag { InputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +#[derive(Debug)] +pub struct InputSeenTag { + is_seen: bool, +} + +impl Tag for InputSeenTag { + fn new(result: bool) -> InputSeenTag { InputSeenTag { is_seen: result } } + fn result(&self) -> bool { self.is_seen } +} + +#[derive(Debug)] +pub struct OutputOwnedTag { + is_owned: bool, +} + +impl Tag for OutputOwnedTag { + fn new(result: bool) -> OutputOwnedTag { OutputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +/// Helper function to run validation callback over a list of [`Reference`]s +pub fn check_references( + references: impl Iterator>, + check: &mut impl FnMut(&V) -> Result, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut checked_references: Vec> = vec![]; + for reference in references { + let result = check(&reference.get_value())?; + checked_references.push(reference.mark(result)); + } + Ok(checked_references.into_iter()) +} + +/// Validate that the [`TaggedReference`]s are in the correct order and are a complete set. +fn validate_checks( + checked_references: impl IntoIterator>, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut current_index = 0; + let mut is_complete = false; + let mut validated_refs: Vec> = vec![]; + for reference in checked_references { + if reference.get_index() != current_index { + return Err(ImplementationError::from( + "Missing reference check at index {current_index}", + )); + } + if reference.get_index() == reference.final_index { + is_complete = true + } else { + current_index += 1; + } + validated_refs.push(reference); + } + if !is_complete { + return Err(ImplementationError::from("Missing reference check at index {current_index}")); + } + Ok(validated_refs.into_iter()) +} + /// Validate the payload of a Payjoin request for PSBT and Params sanity pub(crate) fn parse_payload( base64: &str, @@ -254,7 +391,7 @@ pub struct PsbtContext { impl PsbtContext { /// Prepare the PSBT by creating a new PSBT and copying only the fields allowed by the [spec](https://github.com/bitcoin/bips/blob/master/bip-0078.mediawiki#senders-payjoin-proposal-checklist) - fn prepare_psbt(self, processed_psbt: Psbt) -> Psbt { + fn prepare_psbt(&self, processed_psbt: Psbt) -> Psbt { tracing::trace!("Original PSBT from callback: {processed_psbt:#?}"); // Create a new PSBT and copy only the allowed fields @@ -329,16 +466,12 @@ impl PsbtContext { psbt } - /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// Finalizes the signed payjoin proposal PSBT which the sender will find acceptable before /// they sign the transaction and broadcast it to the network. /// - /// Finalization consists of signing and finalizing the PSBT using the passed `wallet_process_psbt` signing function. - fn finalize_proposal( - self, - wallet_process_psbt: impl Fn(&Psbt) -> Result, - ) -> Result { - let psbt = self.psbt_to_sign(); - let signed_psbt = wallet_process_psbt(&psbt)?; + /// Returns a final payjoin proposal PSBT after verifying the signed PSBT matches the payjoin + /// proposal PSBT and sanitizing it. + fn finalize_signed_proposal(&self, signed_psbt: Psbt) -> Result { let expected_ntxid = self.payjoin_psbt.unsigned_tx.compute_ntxid(); let actual_ntxid = signed_psbt.unsigned_tx.compute_ntxid(); if expected_ntxid != actual_ntxid { @@ -370,6 +503,17 @@ impl OriginalPayload { &self, min_fee_rate: Option, can_broadcast: impl Fn(&bitcoin::Transaction) -> Result, + ) -> Result<(), Error> { + self.apply_broadcast_suitability( + min_fee_rate, + can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate())?, + ) + } + + pub fn apply_broadcast_suitability( + &self, + min_fee_rate: Option, + can_broadcast: bool, ) -> Result<(), Error> { let original_psbt_fee_rate = self.psbt_fee_rate()?; if let Some(min_fee_rate) = min_fee_rate { @@ -381,80 +525,144 @@ impl OriginalPayload { .into()); } } - if can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate()) - .map_err(Error::Implementation)? - { + if can_broadcast { Ok(()) } else { Err(InternalPayloadError::OriginalPsbtNotBroadcastable.into()) } } - /// Check that the original PSBT has no receiver-owned inputs. + /// Check that the original PSBT has no receiver owned inputs. /// /// An attacker can try to spend the receiver's own inputs. This check prevents that. pub fn check_inputs_not_owned( &self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result<(), Error> { - let mut err: Result<(), Error> = Ok(()); - if let Some(e) = self + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + let final_index = self.psbt.input_pairs().count() - 1; + let script_references = self .psbt .input_pairs() - .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => Some(txout.script_pubkey.to_owned()), - Err(e) => { - **err = Err(InternalPayloadError::PrevTxOut(e).into()); - None - } - }) - .find_map(|script| match is_owned(&script) { - Ok(false) => None, - Ok(true) => Some(InternalPayloadError::InputOwned(script).into()), - Err(e) => Some(Error::Implementation(e)), + .enumerate() + .map(|(index, input)| match input.previous_txout() { + Ok(txout) => Ok(Reference::::new( + txout.script_pubkey.to_owned(), + index, + final_index, + )), + Err(e) => Err(InternalPayloadError::PrevTxOut(e)), }) - { - return Err(e); + .collect::>, InternalPayloadError>>()?; + Ok(script_references.into_iter()) + } + + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_scripts)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => + Err(InternalPayloadError::InputOwned(checked_input.get_value()).into()), + None => Ok(()), } - err?; - Ok(()) } pub fn check_no_inputs_seen_before( &self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result<(), Error> { - self.psbt.input_pairs().try_for_each(|input| { - match is_known(&input.txin.previous_output) { - Ok(false) => Ok::<(), Error>(()), - Ok(true) => { - tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", input.txin.previous_output); - Err(InternalPayloadError::InputSeen(input.txin.previous_output))? - }, - Err(e) => Err(Error::Implementation(e))?, + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.input_pairs().count() - 1; + let outpoint_references = self + .psbt + .input_pairs() + .enumerate() + .map(|(index, input)| { + Reference::::new( + input.txin.previous_output, + index, + final_index, + ) + }) + .collect::>(); + outpoint_references.into_iter() + } + + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_outpoints)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => { + tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", checked_input.get_value()); + Err(InternalPayloadError::InputSeen(checked_input.get_value()))? } - })?; - Ok(()) + None => Ok(()), + } } pub fn identify_receiver_outputs( self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - let owned_vouts: Vec = self + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.unsigned_tx.output.len() - 1; + let script_references = self .psbt .unsigned_tx .output .iter() .enumerate() - .filter_map(|(vout, txo)| match is_receiver_output(&txo.script_pubkey) { - Ok(true) => Some(Ok(vout)), - Ok(false) => None, - Err(e) => Some(Err(e)), + .map(|(index, output)| { + Reference::::new( + output.script_pubkey.clone(), + index, + final_index, + ) }) - .collect::, _>>() - .map_err(Error::Implementation)?; + .collect::>(); + script_references.into_iter() + } + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + let validated_checks = validate_checks(checked_output_scripts)?; + let owned_vouts = validated_checks + .into_iter() + .filter_map(|checked_output| match checked_output.get_result() { + true => Some(checked_output.get_index()), + false => None, + }) + .collect::>(); if owned_vouts.is_empty() { return Err(InternalPayloadError::MissingPayment.into()); } @@ -484,7 +692,9 @@ pub(crate) mod tests { witness, Amount, PubkeyHash, ScriptBuf, ScriptHash, Sequence, Txid, WScriptHash, XOnlyPublicKey, }; - use payjoin_test_utils::{DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, QUERY_PARAMS}; + use payjoin_test_utils::{ + DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS, + }; use super::*; use crate::psbt::InternalPsbtInputError::InvalidScriptPubKey; @@ -496,6 +706,24 @@ pub(crate) mod tests { OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } + pub(crate) fn original_missing_prevtxout_from_test_vector() -> OriginalPayload { + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); + let mut psbt: Psbt = PARSED_ORIGINAL_PSBT.clone(); + for psbtin in psbt.inputs_mut() { + psbtin.non_witness_utxo = None; + psbtin.witness_utxo = None; + } + OriginalPayload { psbt: psbt.clone(), params } + } + + pub(crate) fn psbt_context_from_test_vector() -> PsbtContext { + PsbtContext { + payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), + original_psbt: PARSED_ORIGINAL_PSBT.clone(), + } + } + #[test] fn input_pair_with_expected_weight() { let p2wsh_txout = TxOut { @@ -830,6 +1058,141 @@ pub(crate) mod tests { assert_eq!(err, PsbtInputError::from(InternalPsbtInputError::ProvidedUnnecessaryWeight)); } + #[test] + fn test_check_broadcast_suitability() { + let original = original_from_test_vector(); + + // Outcome 1: min_fee_rate too high → PsbtBelowFeeRate error + let err = original + .clone() + .check_broadcast_suitability(Some(FeeRate::MAX), |_| Ok(true)) + .expect_err("Should fail when fee rate is below minimum"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PsbtBelowFeeRate(original_fee_rate, min_fee_rate), + ))) => { + assert_eq!(original_fee_rate, original.psbt_fee_rate().unwrap()); + assert_eq!(min_fee_rate, FeeRate::MAX); + } + _ => panic!("Expected PsbtBelowFeeRate error, got: {err:?}"), + } + + // Outcome 2: can_broadcast returns false → OriginalPsbtNotBroadcastable error + let err = original + .clone() + .check_broadcast_suitability(None, |_| Ok(false)) + .expect_err("Should fail when can_broadcast returns false"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::OriginalPsbtNotBroadcastable, + ))) => {} + _ => panic!("Expected OriginalPsbtNotBroadcastable error, got: {err:?}"), + } + + // Outcome 3: can_broadcast returns an implementation error → Error::Implementation + let err = original + .clone() + .check_broadcast_suitability(None, |_| { + Err(ImplementationError::from("broadcast check failed")) + }) + .expect_err("Should fail when can_broadcast returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "broadcast check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: success + original + .check_broadcast_suitability(None, |_| Ok(true)) + .expect("Should succeed when fee rate is acceptable and can_broadcast returns true"); + } + + #[test] + fn test_check_inputs_not_owned() { + let original = original_from_test_vector(); + let original_missing_prevtxout = original_missing_prevtxout_from_test_vector(); + + // Outcome 1: input_scripts returns a PrevTxOut error → Protocol error + let err = original_missing_prevtxout + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect_err("Should fail when previous txout is missing"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PrevTxOut(_), + ))) => {} + _ => panic!("Expected PrevTxOut error, got: {err:?}"), + } + + // Outcome 2: is_owned returns true → InputOwned error + let err = original + .clone() + .check_inputs_not_owned(&mut |_| Ok(true)) + .expect_err("Should fail when input is owned"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputOwned(_), + ))) => {} + _ => panic!("Expected InputOwned error, got: {err:?}"), + } + + // Outcome 3: is_owned returns an implementation error → Error::Implementation + let err = original + .clone() + .check_inputs_not_owned(&mut |_| { + Err(ImplementationError::from("ownership check failed")) + }) + .expect_err("Should fail when is_owned returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "ownership check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: is_owned returns false → success + original + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect("Should succeed when no inputs are owned"); + } + + #[test] + fn test_check_no_inputs_seen_before() { + let original = original_from_test_vector(); + + // Outcome 1: is_known returns true → InputSeen error + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| Ok(true)) + .expect_err("Should fail when input has been seen before"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputSeen(_), + ))) => {} + _ => panic!("Expected InputSeen error, got: {err:?}"), + } + + // Outcome 2: is_known returns an implementation error → Error::Implementation + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| { + Err(ImplementationError::from("input seen check failed")) + }) + .expect_err("Should fail when is_known returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "input seen check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 3: is_known returns false → success + original + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .expect("Should succeed when no inputs have been seen before"); + } + #[test] fn test_identify_receiver_outputs() { let original = original_from_test_vector(); @@ -864,4 +1227,23 @@ pub(crate) mod tests { assert_eq!(wants_outputs.owned_vouts, vec![0, 1]); assert_eq!(wants_outputs.params.additional_fee_contribution, None); } + + #[test] + fn test_finalize_proposal() { + // Outcome 1: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError + let psbt_context = psbt_context_from_test_vector(); + let err = psbt_context + .clone() + .finalize_signed_proposal( + // return a totally different psbt to trigger ntxid mismatch + PARSED_ORIGINAL_PSBT.clone(), + ) + .expect_err("Should fail when ntxid mismatches"); + assert!(err.to_string().contains("Ntxid mismatch")); + + // Outcome 2: wallet_process_psbt succeeds → Ok(Psbt) + let _psbt = psbt_context + .finalize_signed_proposal(PARSED_PAYJOIN_PROPOSAL.clone()) + .expect("Should succeed when wallet_process_psbt returns a valid signed psbt"); + } } diff --git a/payjoin/src/core/receive/v1/mod.rs b/payjoin/src/core/receive/v1/mod.rs index a43ec727e..e102593e9 100644 --- a/payjoin/src/core/receive/v1/mod.rs +++ b/payjoin/src/core/receive/v1/mod.rs @@ -113,6 +113,42 @@ impl UncheckedOriginalPayload { Ok(MaybeInputsOwned { original: self.original }) } + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Self::apply_broadcast_suitability`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Self::extract_tx_to_check_broadcast_suitability`] first to acquire the tx + /// to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn apply_broadcast_suitability( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> Result { + self.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable)?; + Ok(MaybeInputsOwned { original: self.original }) + } + /// Moves on to the next typestate without any of the current typestate's validations. /// /// Use this for interactive payment receivers, where there is no risk of a probing attack since the @@ -151,7 +187,35 @@ impl MaybeInputsOwned { self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.check_inputs_not_owned(is_owned)?; + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Self::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_input_owned_checks(checked_input_scripts)?; Ok(MaybeInputsSeen { original: self.original }) } } @@ -176,7 +240,42 @@ impl MaybeInputsSeen { self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result { - self.original.check_no_inputs_seen_before(is_known)?; + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Self::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result { + self.original.apply_input_seen_checks(checked_input_outpoints)?; Ok(OutputsUnknown { original: self.original }) } } @@ -208,7 +307,49 @@ impl OutputsUnknown { self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.identify_receiver_outputs(is_receiver_output) + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Self::get_output_script_refs`] to obtain the references that need to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_output_owned_checks(checked_output_scripts) } } @@ -290,11 +431,9 @@ impl ProvisionalProposal { self, wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> Result { - let finalized_psbt = self - .psbt_context - .finalize_proposal(wallet_process_psbt) - .map_err(|e| Error::Implementation(ImplementationError::new(e)))?; - Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt)?; + self.finalize_signed_proposal(&signed_psbt) } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -303,6 +442,17 @@ impl ProvisionalProposal { /// is different from the entity that has access to the private keys, /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.psbt_context.psbt_to_sign() } + + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Self::psbt_to_sign`] to obtain the payjoin proposal's unsigned + /// PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal(self, signed_psbt: &Psbt) -> Result { + let finalized_psbt = self.psbt_context.finalize_signed_proposal(signed_psbt.clone())?; + Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + } } /// A finalized Payjoin proposal, complete with fees and receiver signatures, that the sender diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 569ca7d6e..fb3e0d12a 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -30,7 +30,7 @@ use std::time::Duration; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; -use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, TxOut, Txid}; +use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, ScriptBuf, Transaction, TxOut, Txid}; pub(crate) use error::InternalSessionError; pub use error::SessionError; use serde::de::Deserializer; @@ -57,7 +57,10 @@ use crate::persist::{ MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults, MaybeSuccessTransition, MaybeTransientTransition, NextStateTransition, TerminalTransition, }; -use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext}; +use crate::receive::{ + check_references, parse_payload, InputOwnedTag, InputPair, InputSeenTag, OriginalPayload, + OutputOwnedTag, PsbtContext, Reference, TaggedReference, +}; use crate::time::Time; use crate::uri::ShortId; use crate::{ImplementationError, IntoUrl, IntoUrlError, Request, Version}; @@ -640,7 +643,68 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_broadcast_suitability(min_fee_rate, can_broadcast) { + let tx = self.extract_tx_to_check_broadcast_suitability(); + match can_broadcast(&tx) { + Ok(is_broadcast_suitable) => + self.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Moves on to the next typestate without any of the current typestate's validations. + /// + /// Use this for interactive payment receivers, where there is no risk of a probing attack since the + /// receiver needs to manually create payjoin URIs. + pub fn assume_interactive_receiver( + self, + ) -> NextStateTransition> { + NextStateTransition::success( + SessionEvent::CheckedBroadcastSuitability(), + Receiver { + state: MaybeInputsOwned { original: self.original.clone() }, + session_context: self.session_context, + }, + ) + } + + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Receiver::apply_broadcast_suitability`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Receiver::extract_tx_to_check_broadcast_suitability`] first to + /// acquire the tx to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn apply_broadcast_suitability( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable) { Ok(()) => MaybeFatalTransition::success( SessionEvent::CheckedBroadcastSuitability(), Receiver { @@ -661,22 +725,6 @@ impl Receiver { } } - /// Moves on to the next typestate without any of the current typestate's validations. - /// - /// Use this for interactive payment receivers, where there is no risk of a probing attack since the - /// receiver needs to manually create payjoin URIs. - pub fn assume_interactive_receiver( - self, - ) -> NextStateTransition> { - NextStateTransition::success( - SessionEvent::CheckedBroadcastSuitability(), - Receiver { - state: MaybeInputsOwned { original: self.original.clone() }, - session_context: self.session_context, - }, - ) - } - pub(crate) fn apply_checked_broadcast_suitability(self) -> ReceiveSession { let new_state = Receiver { state: MaybeInputsOwned { original: self.original.clone() }, @@ -720,31 +768,74 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_inputs_not_owned(is_owned) { - Ok(inner) => inner, + match self.get_input_script_refs() { + Ok(input_scripts) => match check_references(input_scripts, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + }) { + Ok(checked_input_scripts) => self.apply_input_owned_checks(checked_input_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + }, Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedInputsNotOwned(), - Receiver { - state: MaybeInputsSeen { original: self.original.clone() }, - session_context: self.session_context, + } + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.state.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Receiver::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_owned_checks(checked_input_scripts) { + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedInputsNotOwned(), + Receiver { + state: MaybeInputsSeen { original: self.original.clone() }, + session_context: self.session_context, + }, + ), + Err(e) => match e { + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_inputs_not_owned(self) -> ReceiveSession { @@ -782,31 +873,69 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_no_inputs_seen_before(is_known) { - Ok(inner) => inner, + match check_references(self.get_input_outpoint_refs(), is_known) { + Ok(checked_input_outpoints) => self.apply_input_seen_checks(checked_input_outpoints), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Receiver::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_seen_checks(checked_input_outpoints) { + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedNoInputsSeenBefore(), + Receiver { + state: OutputsUnknown { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedNoInputsSeenBefore(), - Receiver { - state: OutputsUnknown { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_no_inputs_seen_before(self) -> ReceiveSession { @@ -849,28 +978,71 @@ impl Receiver { Error, Receiver, > { - let inner = match self.state.original.identify_receiver_outputs(is_receiver_output) { - Ok(inner) => inner, + match check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + }) { + Ok(checked_output_scripts) => self.apply_output_owned_checks(checked_output_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Receiver::get_output_script_refs`] to obtain the references that need + /// to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn apply_output_owned_checks( + self, + checked_output_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_output_owned_checks(checked_output_scripts) { + Ok(inner) => MaybeFatalTransition::success( + SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), + Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), - Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, - ) + } } pub(crate) fn apply_identified_receiver_outputs( @@ -1054,23 +1226,20 @@ impl Receiver { ) -> MaybeFatalTransition, ProtocolError> { let max_effective_fee_rate = max_effective_fee_rate.or(Some(self.session_context.max_fee_rate)); - let psbt_context = match self + match self .state .inner .calculate_psbt_context_with_fee_range(min_fee_rate, max_effective_fee_rate) { - Ok(inner) => inner, - Err(e) => { - return MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())); - } - }; - MaybeFatalTransition::success( - SessionEvent::AppliedFeeRange(psbt_context.clone()), - Receiver { - state: ProvisionalProposal { psbt_context }, - session_context: self.session_context, - }, - ) + Ok(psbt_context) => MaybeFatalTransition::success( + SessionEvent::AppliedFeeRange(psbt_context.clone()), + Receiver { + state: ProvisionalProposal { psbt_context }, + session_context: self.session_context, + }, + ), + Err(e) => MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())), + } } pub(crate) fn apply_applied_fee_range(self, psbt_context: PsbtContext) -> ReceiveSession { @@ -1102,19 +1271,12 @@ impl Receiver { wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> MaybeTransientTransition, ImplementationError> { - let original_psbt = self.state.psbt_context.original_psbt.clone(); - let inner = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { - Ok(inner) => inner, - Err(e) => { - return MaybeTransientTransition::transient(e); - } - }; - let psbt_context = PsbtContext { payjoin_psbt: inner.clone(), original_psbt }; - let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; - MaybeTransientTransition::success( - SessionEvent::FinalizedProposal(inner), - Receiver { state: payjoin_proposal, session_context: self.session_context }, - ) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt); + match signed_psbt { + Ok(signed_psbt) => self.finalize_signed_proposal(&signed_psbt), + Err(e) => MaybeTransientTransition::transient(e), + } } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -1124,6 +1286,33 @@ impl Receiver { /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.state.psbt_context.psbt_to_sign() } + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Receiver::psbt_to_sign`] to obtain the payjoin + /// proposal's unsigned PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal( + self, + signed_psbt: &Psbt, + ) -> MaybeTransientTransition, ImplementationError> + { + let original_psbt = self.state.psbt_context.original_psbt.clone(); + let payjoin_psbt = + match self.state.psbt_context.finalize_signed_proposal(signed_psbt.clone()) { + Ok(payjoin_psbt) => payjoin_psbt, + Err(e) => { + return MaybeTransientTransition::transient(e); + } + }; + let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; + let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; + MaybeTransientTransition::success( + SessionEvent::FinalizedProposal(payjoin_psbt), + Receiver { state: payjoin_proposal, session_context: self.session_context }, + ) + } + pub(crate) fn apply_payjoin_proposal(self, payjoin_psbt: Psbt) -> ReceiveSession { let psbt_context = PsbtContext { payjoin_psbt, @@ -1329,66 +1518,95 @@ impl Receiver { &self, transaction_exists: impl Fn(Txid) -> Result, ImplementationError>, ) -> MaybeFatalOrSuccessTransition { - let fallback_tx = self - .state - .psbt_context - .original_psbt - .clone() - .extract_tx_fee_rate_limit() - .expect("fallback transaction should be in the receiver context"); - - // If the fallback transaction included any non-SegWit inputs, then the transaction ID of - // the Payjoin proposal is going to change when the sender signs their non-SegWit address - // one more time. The receiver cannot monitor the transaction, and should conclude the session. - if fallback_tx.input.iter().any(|txin| txin.witness.is_empty()) { - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::PayjoinProposalSent, - )); + if let transition @ MaybeFatalOrSuccessTransition::Success(_) = + self.check_fallback_monitorable() + { + return transition; } - let payjoin_proposal = &self.state.psbt_context.payjoin_psbt; - let payjoin_txid = payjoin_proposal.unsigned_tx.compute_txid(); // If the sender is spending SegWit-only inputs, then the transaction ID of the Payjoin proposal // is not going to change when the sender signs it. So we can use the TXID to check the // network for the Payjoin proposal. - match transaction_exists(payjoin_txid) { - Ok(Some(tx)) => { - let tx_id = tx.compute_txid(); - if tx_id != payjoin_txid { - return MaybeFatalOrSuccessTransition::transient(Error::Implementation( - ImplementationError::from(format!("Payjoin transaction ID mismatch. Expected: {payjoin_txid}, Got: {tx_id}").as_str()), - )); - } - // TODO: should we check for witness and scriptsig on the tx? - let mut sender_witnesses = vec![]; - - for i in self.state.psbt_context.sender_input_indexes() { - let input = - tx.input.get(i).expect("sender_input_indexes should return valid indices"); - sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); - } - // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::Success(sender_witnesses), - )); - } + match transaction_exists(self.extract_payjoin_proposal_txid()) { + Ok(Some(tx)) => return self.payjoin_tx_exists(tx), Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), } // If the Payjoin proposal was not found, check the fallback transaction, as it is // the second of two transactions whose IDs the receiver is aware of. - match transaction_exists(fallback_tx.compute_txid()) { - Ok(Some(_)) => - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::FallbackBroadcasted, - )), + match transaction_exists(self.extract_fallback_txid()) { + Ok(Some(_)) => return self.fallback_tx_exists(), Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), } MaybeFatalOrSuccessTransition::no_results(self.clone()) } + + pub fn extract_fallback_tx(&self) -> Transaction { + self.state.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate() + } + + pub fn extract_fallback_txid(&self) -> Txid { self.extract_fallback_tx().compute_txid() } + + pub fn extract_payjoin_proposal_txid(&self) -> Txid { + self.state.psbt_context.payjoin_psbt.clone().extract_tx_unchecked_fee_rate().compute_txid() + } + + pub fn check_fallback_monitorable( + &self, + ) -> MaybeFatalOrSuccessTransition { + // If the fallback transaction included any non-SegWit inputs, then the transaction ID of + // the Payjoin proposal is going to change when the sender signs their non-SegWit address + // one more time. The receiver cannot monitor the transaction, and should conclude the session. + if has_empty_witness(&self.extract_fallback_tx()) { + return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::PayjoinProposalSent, + )); + } + + MaybeFatalOrSuccessTransition::no_results(self.clone()) + } + + pub fn fallback_tx_exists(&self) -> MaybeFatalOrSuccessTransition { + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::FallbackBroadcasted, + )) + } + + pub fn payjoin_tx_exists( + &self, + tx: Transaction, + ) -> MaybeFatalOrSuccessTransition { + // TODO: should we check for witness and scriptsig on the tx? + let payjoin_txid = self.state.psbt_context.payjoin_psbt.unsigned_tx.compute_txid(); + let tx_id = tx.compute_txid(); + if tx_id != payjoin_txid { + return MaybeFatalOrSuccessTransition::transient(Error::Implementation( + ImplementationError::from( + format!( + "Payjoin transaction ID mismatch. Expected: {payjoin_txid}, Got: {tx_id}" + ) + .as_str(), + ), + )); + } + let mut sender_witnesses = vec![]; + + for i in self.state.psbt_context.sender_input_indexes() { + let input = tx.input.get(i).expect("sender_input_indexes should return valid indices"); + sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); + } + // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Success( + sender_witnesses, + ))) + } +} + +fn has_empty_witness(tx: &Transaction) -> bool { + tx.input.iter().any(|txin| txin.witness.is_empty()) } /// Derive a mailbox endpoint on a directory given a [`ShortId`]. @@ -1599,20 +1817,22 @@ pub mod test { Ok(ret) } - let maybe_inputs_seen = - receiver.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)); + let maybe_inputs_seen = receiver + .check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)) + .save(&persister) + .expect("Persister shouldn't fail"); assert_eq!(call_count, 1); let outputs_unknown = maybe_inputs_seen - .save(&persister) - .expect("Persister shouldn't fail") .check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false)) .save(&persister) .expect("Persister shouldn't fail"); assert_eq!(call_count, 2); let _wants_outputs = outputs_unknown - .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)); + .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)) + .save(&persister) + .expect("Persister shouldn't fail"); // there are 2 receiver outputs so we should expect this callback to run twice incrementing // call count twice assert_eq!(call_count, 4);