From 20ddc7b2773cb07532a80fa808d1a73355d9400a Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Tue, 12 May 2026 14:26:09 -0500 Subject: [PATCH 1/8] Add tests for OriginalPayload outcomes Cover the success and error paths for each validation step in the receiver's OriginalPayload's broadcast suitability, input ownership, and input known validation as well as PsbtContext's proposal finalization. --- payjoin/src/core/receive/mod.rs | 185 +++++++++++++++++++++++++++++++- 1 file changed, 184 insertions(+), 1 deletion(-) diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index 5b28abcf5..bac207c53 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -484,7 +484,9 @@ pub(crate) mod tests { witness, Amount, PubkeyHash, ScriptBuf, ScriptHash, Sequence, Txid, WScriptHash, XOnlyPublicKey, }; - use payjoin_test_utils::{DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, QUERY_PARAMS}; + use payjoin_test_utils::{ + DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS, + }; use super::*; use crate::psbt::InternalPsbtInputError::InvalidScriptPubKey; @@ -496,6 +498,24 @@ pub(crate) mod tests { OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } + pub(crate) fn original_missing_prevtxout_from_test_vector() -> OriginalPayload { + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); + let mut psbt: Psbt = PARSED_ORIGINAL_PSBT.clone(); + for psbtin in psbt.inputs_mut() { + psbtin.non_witness_utxo = None; + psbtin.witness_utxo = None; + } + OriginalPayload { psbt: psbt.clone(), params } + } + + pub(crate) fn psbt_context_from_test_vector() -> PsbtContext { + PsbtContext { + payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), + original_psbt: PARSED_ORIGINAL_PSBT.clone(), + } + } + #[test] fn input_pair_with_expected_weight() { let p2wsh_txout = TxOut { @@ -830,6 +850,141 @@ pub(crate) mod tests { assert_eq!(err, PsbtInputError::from(InternalPsbtInputError::ProvidedUnnecessaryWeight)); } + #[test] + fn test_check_broadcast_suitability() { + let original = original_from_test_vector(); + + // Outcome 1: min_fee_rate too high → PsbtBelowFeeRate error + let err = original + .clone() + .check_broadcast_suitability(Some(FeeRate::MAX), |_| Ok(true)) + .expect_err("Should fail when fee rate is below minimum"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PsbtBelowFeeRate(original_fee_rate, min_fee_rate), + ))) => { + assert_eq!(original_fee_rate, original.psbt_fee_rate().unwrap()); + assert_eq!(min_fee_rate, FeeRate::MAX); + } + _ => panic!("Expected PsbtBelowFeeRate error, got: {err:?}"), + } + + // Outcome 2: can_broadcast returns false → OriginalPsbtNotBroadcastable error + let err = original + .clone() + .check_broadcast_suitability(None, |_| Ok(false)) + .expect_err("Should fail when can_broadcast returns false"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::OriginalPsbtNotBroadcastable, + ))) => {} + _ => panic!("Expected OriginalPsbtNotBroadcastable error, got: {err:?}"), + } + + // Outcome 3: can_broadcast returns an implementation error → Error::Implementation + let err = original + .clone() + .check_broadcast_suitability(None, |_| { + Err(ImplementationError::from("broadcast check failed")) + }) + .expect_err("Should fail when can_broadcast returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "broadcast check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: success + original + .check_broadcast_suitability(None, |_| Ok(true)) + .expect("Should succeed when fee rate is acceptable and can_broadcast returns true"); + } + + #[test] + fn test_check_inputs_not_owned() { + let original = original_from_test_vector(); + let original_missing_prevtxout = original_missing_prevtxout_from_test_vector(); + + // Outcome 1: input_scripts returns a PrevTxOut error → Protocol error + let err = original_missing_prevtxout + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect_err("Should fail when previous txout is missing"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PrevTxOut(_), + ))) => {} + _ => panic!("Expected PrevTxOut error, got: {err:?}"), + } + + // Outcome 2: is_owned returns true → InputOwned error + let err = original + .clone() + .check_inputs_not_owned(&mut |_| Ok(true)) + .expect_err("Should fail when input is owned"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputOwned(_), + ))) => {} + _ => panic!("Expected InputOwned error, got: {err:?}"), + } + + // Outcome 3: is_owned returns an implementation error → Error::Implementation + let err = original + .clone() + .check_inputs_not_owned(&mut |_| { + Err(ImplementationError::from("ownership check failed")) + }) + .expect_err("Should fail when is_owned returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "ownership check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: is_owned returns false → success + original + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect("Should succeed when no inputs are owned"); + } + + #[test] + fn test_check_no_inputs_seen_before() { + let original = original_from_test_vector(); + + // Outcome 1: is_known returns true → InputSeen error + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| Ok(true)) + .expect_err("Should fail when input has been seen before"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputSeen(_), + ))) => {} + _ => panic!("Expected InputSeen error, got: {err:?}"), + } + + // Outcome 2: is_known returns an implementation error → Error::Implementation + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| { + Err(ImplementationError::from("input seen check failed")) + }) + .expect_err("Should fail when is_known returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "input seen check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 3: is_known returns false → success + original + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .expect("Should succeed when no inputs have been seen before"); + } + #[test] fn test_identify_receiver_outputs() { let original = original_from_test_vector(); @@ -864,4 +1019,32 @@ pub(crate) mod tests { assert_eq!(wants_outputs.owned_vouts, vec![0, 1]); assert_eq!(wants_outputs.params.additional_fee_contribution, None); } + + #[test] + fn test_finalize_proposal() { + let psbt_context = psbt_context_from_test_vector(); + + // Outcome 1: wallet_process_psbt returns an implementation error → ImplementationError + let err = psbt_context + .clone() + .finalize_proposal(|_| Err(ImplementationError::from("wallet signing failed"))) + .expect_err("Should fail when wallet_process_psbt returns an error"); + assert_eq!(err.to_string(), "wallet signing failed"); + + // Outcome 2: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError + let psbt_context = psbt_context_from_test_vector(); + let err = psbt_context + .clone() + .finalize_proposal(|_| { + // return a totally different psbt to trigger ntxid mismatch + Ok(PARSED_ORIGINAL_PSBT.clone()) + }) + .expect_err("Should fail when ntxid mismatches"); + assert!(err.to_string().contains("Ntxid mismatch")); + + // Outcome 3: wallet_process_psbt succeeds → Ok(Psbt) + let _psbt = psbt_context + .finalize_proposal(|_| Ok(PARSED_PAYJOIN_PROPOSAL.clone())) + .expect("Should succeed when wallet_process_psbt returns a valid signed psbt"); + } } From 446dc9dd7551d8193b69e585f10eb909c346aaca Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Tue, 5 May 2026 10:28:44 -0500 Subject: [PATCH 2/8] Light refactor to v2 receiver for readability Restructure match arms in the v2 receiver typestates to return directly from each branch instead of binding an intermediate value and returning after the match. Also rename a shadowed `inner` binding to `payjoin_psbt` for clarity. --- payjoin/src/core/receive/v2/mod.rs | 156 +++++++++++++---------------- 1 file changed, 70 insertions(+), 86 deletions(-) diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 569ca7d6e..2c3f1bd5b 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -721,30 +721,25 @@ impl Receiver { Receiver, > { match self.state.original.check_inputs_not_owned(is_owned) { - Ok(inner) => inner, + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedInputsNotOwned(), + Receiver { + state: MaybeInputsSeen { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedInputsNotOwned(), - Receiver { - state: MaybeInputsSeen { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_inputs_not_owned(self) -> ReceiveSession { @@ -783,30 +778,25 @@ impl Receiver { Receiver, > { match self.state.original.check_no_inputs_seen_before(is_known) { - Ok(inner) => inner, + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedNoInputsSeenBefore(), + Receiver { + state: OutputsUnknown { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedNoInputsSeenBefore(), - Receiver { - state: OutputsUnknown { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_no_inputs_seen_before(self) -> ReceiveSession { @@ -849,28 +839,23 @@ impl Receiver { Error, Receiver, > { - let inner = match self.state.original.identify_receiver_outputs(is_receiver_output) { - Ok(inner) => inner, + match self.state.original.identify_receiver_outputs(is_receiver_output) { + Ok(inner) => MaybeFatalTransition::success( + SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), + Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), - Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, - ) + } } pub(crate) fn apply_identified_receiver_outputs( @@ -1054,23 +1039,20 @@ impl Receiver { ) -> MaybeFatalTransition, ProtocolError> { let max_effective_fee_rate = max_effective_fee_rate.or(Some(self.session_context.max_fee_rate)); - let psbt_context = match self + match self .state .inner .calculate_psbt_context_with_fee_range(min_fee_rate, max_effective_fee_rate) { - Ok(inner) => inner, - Err(e) => { - return MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())); - } - }; - MaybeFatalTransition::success( - SessionEvent::AppliedFeeRange(psbt_context.clone()), - Receiver { - state: ProvisionalProposal { psbt_context }, - session_context: self.session_context, - }, - ) + Ok(psbt_context) => MaybeFatalTransition::success( + SessionEvent::AppliedFeeRange(psbt_context.clone()), + Receiver { + state: ProvisionalProposal { psbt_context }, + session_context: self.session_context, + }, + ), + Err(e) => MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())), + } } pub(crate) fn apply_applied_fee_range(self, psbt_context: PsbtContext) -> ReceiveSession { @@ -1103,16 +1085,16 @@ impl Receiver { ) -> MaybeTransientTransition, ImplementationError> { let original_psbt = self.state.psbt_context.original_psbt.clone(); - let inner = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { - Ok(inner) => inner, + let payjoin_psbt = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { + Ok(payjoin_psbt) => payjoin_psbt, Err(e) => { return MaybeTransientTransition::transient(e); } }; - let psbt_context = PsbtContext { payjoin_psbt: inner.clone(), original_psbt }; + let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; MaybeTransientTransition::success( - SessionEvent::FinalizedProposal(inner), + SessionEvent::FinalizedProposal(payjoin_psbt), Receiver { state: payjoin_proposal, session_context: self.session_context }, ) } @@ -1599,20 +1581,22 @@ pub mod test { Ok(ret) } - let maybe_inputs_seen = - receiver.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)); + let maybe_inputs_seen = receiver + .check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)) + .save(&persister) + .expect("Persister shouldn't fail"); assert_eq!(call_count, 1); let outputs_unknown = maybe_inputs_seen - .save(&persister) - .expect("Persister shouldn't fail") .check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false)) .save(&persister) .expect("Persister shouldn't fail"); assert_eq!(call_count, 2); let _wants_outputs = outputs_unknown - .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)); + .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)) + .save(&persister) + .expect("Persister shouldn't fail"); // there are 2 receiver outputs so we should expect this callback to run twice incrementing // call count twice assert_eq!(call_count, 4); From 70d24b946cb07c89a6e55f55b4bd5108198154cf Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Thu, 21 May 2026 21:50:26 -0500 Subject: [PATCH 3/8] Refactor Javascript FFI bindings integration test Consolidate standalone receiver processing functions into a ReceiverProcessor class that encapsulates the payjoin module, RPC client, and persister. Fix the PJ helper type to use prototype inference, update the web import paths from src to dist, and correct the CheckInputsNotSeenCallback parameter type. --- .../javascript/test/integration.test.ts | 440 +++++++----------- 1 file changed, 174 insertions(+), 266 deletions(-) diff --git a/payjoin-ffi/javascript/test/integration.test.ts b/payjoin-ffi/javascript/test/integration.test.ts index 36aed0f2c..0dde4a267 100644 --- a/payjoin-ffi/javascript/test/integration.test.ts +++ b/payjoin-ffi/javascript/test/integration.test.ts @@ -7,8 +7,8 @@ import { payjoin as nodejsPayjoin, uniffiInitAsync as nodejsUniffiInitAsync, } from "payjoin"; -import * as webPayjoinModule from "../src/web/generated/payjoin.js"; -import initWebAsync from "../src/web/generated/wasm-bindgen/index.js"; +import * as webPayjoinModule from "../dist/web/generated/payjoin.js"; +import initWebAsync from "../dist/web/generated/wasm-bindgen/index.js"; import { InMemoryReceiverPersister, InMemorySenderPersister } from "./utils.ts"; const __filename = fileURLToPath(import.meta.url); @@ -34,10 +34,12 @@ interface Utxo { type PayjoinModule = typeof nodejsPayjoin; const webPayjoin = webPayjoinModule as unknown as PayjoinModule; -// Helper types to avoid repeating InstanceType everywhere. -type PJ = InstanceType< - PayjoinModule[K] & (new (...args: any) => any) ->; +type PJ = PayjoinModule[K] extends { + prototype: infer P; +} + ? P + : never; + type PJNested< K extends keyof PayjoinModule, N extends keyof PayjoinModule[K], @@ -137,7 +139,7 @@ class CheckInputsNotSeenCallback { this.connection = connection; } - callback(_outpoint: ArrayBuffer): boolean { + callback(_outpoint: nodejsPayjoin.OutPoint): boolean { if (this.connection) { } return false; @@ -159,18 +161,6 @@ class ProcessPsbtCallback { } } -function createReceiverContext( - payjoin: PayjoinModule, - address: string, - directory: string, - ohttpKeys: ReturnType, - persister: InMemoryReceiverPersister, -): PJ<"Initialized"> { - return new payjoin.ReceiverBuilder(address, directory, ohttpKeys) - .build() - .save(persister); -} - function buildSweepPsbt( sender: testUtils.RpcClient, pjUri: PJ<"PjUri">, @@ -230,254 +220,175 @@ function getInputs( return inputs; } -async function processProvisionalProposal( - proposal: PJ<"ProvisionalProposal">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - return proposal - .finalizeProposal(new ProcessPsbtCallback(receiver)) - .save(recvPersister); -} +class ReceiverProcessor { + constructor( + private readonly payjoin: PayjoinModule, + private readonly receiver: testUtils.RpcClient, + private readonly recvPersister: InMemoryReceiverPersister, + ) {} + + private async processProvisionalProposal( + proposal: PJ<"ProvisionalProposal">, + ): Promise> { + return proposal + .finalizeProposal(new ProcessPsbtCallback(this.receiver)) + .save(this.recvPersister) as PJ<"PayjoinProposal">; + } -async function processWantsFeeRange( - proposal: PJ<"WantsFeeRange">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsFeeRange = proposal.applyFeeRange(1n, 10n).save(recvPersister); - return await processProvisionalProposal( - wantsFeeRange, - receiver, - recvPersister, - ); -} + private async processWantsFeeRange( + proposal: PJ<"WantsFeeRange">, + ): Promise> { + const provisionalProposal = proposal + .applyFeeRange(1n, 10n) + .save(this.recvPersister) as PJ<"ProvisionalProposal">; + return this.processProvisionalProposal(provisionalProposal); + } -async function processWantsInputs( - payjoin: PayjoinModule, - proposal: PJ<"WantsInputs">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const provisionalProposal = proposal - .contributeInputs(getInputs(payjoin, receiver)) - .commitInputs() - .save(recvPersister); - return await processWantsFeeRange( - provisionalProposal, - receiver, - recvPersister, - ); -} + private async processWantsInputs( + proposal: PJ<"WantsInputs">, + ): Promise> { + const provisionalProposal = proposal + .contributeInputs(getInputs(this.payjoin, this.receiver)) + .commitInputs() + .save(this.recvPersister) as PJ<"WantsFeeRange">; + return this.processWantsFeeRange(provisionalProposal); + } -async function processWantsOutputs( - payjoin: PayjoinModule, - proposal: PJ<"WantsOutputs">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsInputs = proposal.commitOutputs().save(recvPersister); - return await processWantsInputs( - payjoin, - wantsInputs, - receiver, - recvPersister, - ); -} + private async processWantsOutputs( + proposal: PJ<"WantsOutputs">, + ): Promise> { + const wantsInputs = proposal + .commitOutputs() + .save(this.recvPersister) as PJ<"WantsInputs">; + return this.processWantsInputs(wantsInputs); + } -async function processOutputsUnknown( - payjoin: PayjoinModule, - proposal: PJ<"OutputsUnknown">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const wantsOutputs = proposal - .identifyReceiverOutputs(new IsScriptOwnedCallback(receiver)) - .save(recvPersister); - return await processWantsOutputs( - payjoin, - wantsOutputs, - receiver, - recvPersister, - ); -} + private async processOutputsUnknown( + proposal: PJ<"OutputsUnknown">, + ): Promise> { + const wantsOutputs = proposal + .identifyReceiverOutputs(new IsScriptOwnedCallback(this.receiver)) + .save(this.recvPersister) as PJ<"WantsOutputs">; + return this.processWantsOutputs(wantsOutputs); + } -async function processMaybeInputsSeen( - payjoin: PayjoinModule, - proposal: PJ<"MaybeInputsSeen">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const outputsUnknown = proposal - .checkNoInputsSeenBefore(new CheckInputsNotSeenCallback(receiver)) - .save(recvPersister); - return await processOutputsUnknown( - payjoin, - outputsUnknown, - receiver, - recvPersister, - ); -} + private async processMaybeInputsSeen( + proposal: PJ<"MaybeInputsSeen">, + ): Promise> { + const outputsUnknown = proposal + .checkNoInputsSeenBefore( + new CheckInputsNotSeenCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"OutputsUnknown">; + return this.processOutputsUnknown(outputsUnknown); + } -async function processMaybeInputsOwned( - payjoin: PayjoinModule, - proposal: PJ<"MaybeInputsOwned">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const maybeInputsOwned = proposal - .checkInputsNotOwned(new IsScriptOwnedCallback(receiver)) - .save(recvPersister); - return await processMaybeInputsSeen( - payjoin, - maybeInputsOwned, - receiver, - recvPersister, - ); -} + private async processMaybeInputsOwned( + proposal: nodejsPayjoin.MaybeInputsOwned, + ): Promise> { + const maybeInputsSeen = proposal + .checkInputsNotOwned(new IsScriptOwnedCallback(this.receiver)) + .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + return this.processMaybeInputsSeen(maybeInputsSeen); + } -async function processUncheckedProposal( - payjoin: PayjoinModule, - proposal: PJ<"UncheckedOriginalPayload">, - receiver: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, -): Promise> { - const uncheckedProposal = proposal - .checkBroadcastSuitability( - undefined, - new MempoolAcceptanceCallback(receiver), - ) - .save(recvPersister); - return await processMaybeInputsOwned( - payjoin, - uncheckedProposal, - receiver, - recvPersister, - ); -} + private async processUncheckedProposal( + proposal: PJ<"UncheckedOriginalPayload">, + ): Promise> { + const maybeInputsOwned = proposal + .checkBroadcastSuitability( + undefined, + new MempoolAcceptanceCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + return this.processMaybeInputsOwned(maybeInputsOwned); + } -async function retrieveReceiverProposal( - payjoin: PayjoinModule, - receiver: PJ<"Initialized">, - receiverRpc: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, - ohttpRelay: string, -): Promise | null> { - const request = receiver.createPollRequest(ohttpRelay); - const response = await fetch(request.request.url, { - method: "POST", - headers: { "Content-Type": request.request.contentType }, - body: request.request.body, - }); - const responseBuffer = await response.arrayBuffer(); - const res = receiver - .processResponse(responseBuffer, request.clientResponse) - .save(recvPersister); - - if (res instanceof payjoin.InitializedTransitionOutcome.Stasis) { - return null; - } else if (res instanceof payjoin.InitializedTransitionOutcome.Progress) { - const proposal = res.inner.inner; - return await processUncheckedProposal( - payjoin, - proposal, - receiverRpc, - recvPersister, - ); + createReceiverContext( + address: string, + directory: string, + ohttpKeys: ReturnType, + ): PJ<"Initialized"> { + return new this.payjoin.ReceiverBuilder(address, directory, ohttpKeys) + .build() + .save(this.recvPersister) as PJ<"Initialized">; } - throw new Error(`Unknown initialized transition outcome`); -} + private async retrieveReceiverProposal( + session: PJ<"Initialized">, + ohttpRelay: string, + ): Promise | null> { + const request = session.createPollRequest(ohttpRelay); + const response = await fetch(request.request.url, { + method: "POST", + headers: { "Content-Type": request.request.contentType }, + body: request.request.body, + }); + const responseBuffer = await response.arrayBuffer(); + const res = session + .processResponse(responseBuffer, request.clientResponse) + .save(this.recvPersister); + + if (res instanceof this.payjoin.InitializedTransitionOutcome.Stasis) { + return null; + } else if ( + res instanceof this.payjoin.InitializedTransitionOutcome.Progress + ) { + return this.processUncheckedProposal( + res.inner.inner as PJ<"UncheckedOriginalPayload">, + ); + } -async function processReceiverProposal( - payjoin: PayjoinModule, - receiver: - | PJ<"Initialized"> - | PJ<"UncheckedOriginalPayload"> - | PJ<"MaybeInputsOwned"> - | PJ<"MaybeInputsSeen"> - | PJ<"OutputsUnknown"> - | PJ<"WantsOutputs"> - | PJ<"WantsInputs"> - | PJ<"WantsFeeRange"> - | PJ<"ProvisionalProposal"> - | PJ<"PayjoinProposal">, - receiverRpc: testUtils.RpcClient, - recvPersister: InMemoryReceiverPersister, - ohttpRelay: string, -): Promise | null> { - if (receiver instanceof payjoin.Initialized) { - return await retrieveReceiverProposal( - payjoin, - receiver, - receiverRpc, - recvPersister, - ohttpRelay, - ); - } - if (receiver instanceof payjoin.UncheckedOriginalPayload) { - return await processUncheckedProposal( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.MaybeInputsOwned) { - return await processMaybeInputsOwned( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.MaybeInputsSeen) { - return await processMaybeInputsSeen( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.OutputsUnknown) { - return await processOutputsUnknown( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsOutputs) { - return await processWantsOutputs( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsInputs) { - return await processWantsInputs( - payjoin, - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.WantsFeeRange) { - return await processWantsFeeRange(receiver, receiverRpc, recvPersister); - } - if (receiver instanceof payjoin.ProvisionalProposal) { - return await processProvisionalProposal( - receiver, - receiverRpc, - recvPersister, - ); - } - if (receiver instanceof payjoin.PayjoinProposal) { - return receiver; + throw new Error(`Unknown initialized transition outcome`); } - throw new Error(`Unknown receiver state`); + async processReceiverProposal( + receiver: + | PJ<"Initialized"> + | PJ<"UncheckedOriginalPayload"> + | PJ<"MaybeInputsOwned"> + | PJ<"MaybeInputsSeen"> + | PJ<"OutputsUnknown"> + | PJ<"WantsOutputs"> + | PJ<"WantsInputs"> + | PJ<"WantsFeeRange"> + | PJ<"ProvisionalProposal"> + | PJ<"PayjoinProposal">, + ohttpRelay: string, + ): Promise | null> { + if (receiver instanceof this.payjoin.Initialized) { + return this.retrieveReceiverProposal(receiver, ohttpRelay); + } + if (receiver instanceof this.payjoin.UncheckedOriginalPayload) { + return this.processUncheckedProposal(receiver); + } + if (receiver instanceof this.payjoin.MaybeInputsOwned) { + return this.processMaybeInputsOwned(receiver); + } + if (receiver instanceof this.payjoin.MaybeInputsSeen) { + return this.processMaybeInputsSeen(receiver); + } + if (receiver instanceof this.payjoin.OutputsUnknown) { + return this.processOutputsUnknown(receiver); + } + if (receiver instanceof this.payjoin.WantsOutputs) { + return this.processWantsOutputs(receiver); + } + if (receiver instanceof this.payjoin.WantsInputs) { + return this.processWantsInputs(receiver); + } + if (receiver instanceof this.payjoin.WantsFeeRange) { + return this.processWantsFeeRange(receiver); + } + if (receiver instanceof this.payjoin.ProvisionalProposal) { + return this.processProvisionalProposal(receiver); + } + if (receiver instanceof this.payjoin.PayjoinProposal) { + return receiver; + } + + throw new Error(`Unknown receiver state`); + } } function testFfiValidation(payjoin: PayjoinModule): void { @@ -598,21 +509,21 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { ); const recvPersister = new InMemoryReceiverPersister(); + const recvProcessor = new ReceiverProcessor( + payjoin, + receiver, + recvPersister, + ); const senderPersister = new InMemorySenderPersister(); - const session = createReceiverContext( - payjoin, + const session = recvProcessor.createReceiverContext( receiverAddress, directory, ohttpKeys, - recvPersister, ); - let processResponse = await processReceiverProposal( - payjoin, + const processResponse = await recvProcessor.processReceiverProposal( session, - receiver, - recvPersister, ohttpRelay, ); assert.strictEqual( @@ -622,7 +533,7 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { ); const pjUri = session.pjUri(); - const psbt = buildSweepPsbt(sender, pjUri); + const psbt = buildSweepPsbt(sender, pjUri as PJ<"PjUri">); const reqCtx = new payjoin.SenderBuilder(psbt, pjUri) .buildRecommended(1000n) .save(senderPersister); @@ -638,11 +549,8 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { .processResponse(responseBuffer, request.ohttpCtx) .save(senderPersister); - let payjoinProposal = await processReceiverProposal( - payjoin, + const payjoinProposal = await recvProcessor.processReceiverProposal( session, - receiver, - recvPersister, ohttpRelay, ); assert.notStrictEqual( From 181aa738dd049cd87bca3f3cb730a19a57c69691 Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Wed, 13 May 2026 00:28:42 -0500 Subject: [PATCH 4/8] Add non-blocking receive interface Introduce an implementation-agnostic interface for receiver typestates that currently require callback-based validation to advance. Previously, each validation step demanded a synchronous closure, coupling the state machine to the caller's execution model. This made integration difficult for wallets where signing, broadcast checks, or ownership lookups are asynchronous or handled by a separate process. Each callback-based transition is now split into a two-phase pattern: a method to extract the data that needs checking (get_*_refs, extract_tx_*, psbt_to_sign) and a corresponding method to submit results and advance the state (apply_*_checks, apply_broadcast_suitability, finalize_signed_proposal). A lightweight Reference/TaggedReference framework with typed tags ensures completeness and ordering of the submitted checks at runtime. This applies across v1 and v2 receiver flows, including input ownership, input-seen, output ownership, broadcast suitability, proposal finalization, and transaction monitoring. The original closure-based methods are preserved as convenience wrappers over the new API, so this is backward-compatible for existing integrators. --- payjoin/src/core/receive/common/mod.rs | 2 +- payjoin/src/core/receive/error.rs | 5 + payjoin/src/core/receive/mod.rs | 320 ++++++++++++++++---- payjoin/src/core/receive/v1/mod.rs | 166 ++++++++++- payjoin/src/core/receive/v2/mod.rs | 392 ++++++++++++++++++++----- 5 files changed, 744 insertions(+), 141 deletions(-) diff --git a/payjoin/src/core/receive/common/mod.rs b/payjoin/src/core/receive/common/mod.rs index 9a93fb8f0..635c96159 100644 --- a/payjoin/src/core/receive/common/mod.rs +++ b/payjoin/src/core/receive/common/mod.rs @@ -863,7 +863,7 @@ mod tests { .calculate_psbt_context_with_fee_range(None, None) .expect("Contributed inputs should allow for valid fee contributions"); let payjoin_proposal = - psbt_context.finalize_proposal(|_| Ok(processed_psbt.clone())).expect("Valid psbt"); + psbt_context.finalize_signed_proposal(processed_psbt.clone()).expect("Valid psbt"); assert!(payjoin_proposal.xpub.is_empty()); diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index d7bd23e39..bed31cec1 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -3,6 +3,7 @@ use std::{error, fmt}; use crate::error_codes::ErrorCode::{ self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported, }; +use crate::ImplementationError; /// The top-level error type for the payjoin receiver #[derive(Debug)] @@ -29,6 +30,10 @@ impl From for Error { fn from(e: ProtocolError) -> Self { Error::Protocol(e) } } +impl From for Error { + fn from(e: ImplementationError) -> Self { Error::Implementation(e) } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index bac207c53..6144cece3 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -10,6 +10,7 @@ //! version 1, refer to the `receive::v1` module documentation after enabling the `v1` feature. use std::collections::BTreeMap; +use std::marker::PhantomData; use std::str::FromStr; use bitcoin::transaction::InputWeightPrediction; @@ -228,6 +229,142 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> { fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } } } +/// Holds a value that requires some form of boolean check. +#[derive(Debug)] +pub struct Reference { + value: V, + index: usize, + /// The final index in the set of [`Reference`]s to be checked + final_index: usize, + _tag: PhantomData, +} + +impl Reference +where + V: Clone, + T: Tag, +{ + fn new(value: V, index: usize, final_index: usize) -> Self { + Reference { value, index, final_index, _tag: PhantomData } + } + + /// Returns a [`TaggedReference`] that has been marked with the result of the boolean + /// check. + pub fn mark(&self, result: bool) -> TaggedReference { + TaggedReference { + value: self.value.clone(), + index: self.index, + final_index: self.final_index, + tag: T::new(result), + } + } + + /// Extracts the value to to be checked + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Holds the result of a checked [`Reference`]. Can only be created with [`Reference::mark`]. +#[derive(Debug)] +pub struct TaggedReference { + value: V, + index: usize, + final_index: usize, + tag: T, +} + +impl TaggedReference +where + V: Clone, + T: Tag, +{ + pub fn get_result(&self) -> bool { self.tag.result() } + pub fn get_index(&self) -> usize { self.index } + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Trait used to distinguish different types of validation +pub trait Tag { + fn new(result: bool) -> Self; + fn result(&self) -> bool; +} + +#[derive(Debug)] +pub struct InputOwnedTag { + is_owned: bool, +} + +impl Tag for InputOwnedTag { + fn new(result: bool) -> InputOwnedTag { InputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +#[derive(Debug)] +pub struct InputSeenTag { + is_seen: bool, +} + +impl Tag for InputSeenTag { + fn new(result: bool) -> InputSeenTag { InputSeenTag { is_seen: result } } + fn result(&self) -> bool { self.is_seen } +} + +#[derive(Debug)] +pub struct OutputOwnedTag { + is_owned: bool, +} + +impl Tag for OutputOwnedTag { + fn new(result: bool) -> OutputOwnedTag { OutputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +/// Helper function to run validation callback over a list of [`Reference`]s +pub fn check_references( + references: impl Iterator>, + check: &mut impl FnMut(&V) -> Result, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut checked_references: Vec> = vec![]; + for reference in references { + let result = check(&reference.get_value())?; + checked_references.push(reference.mark(result)); + } + Ok(checked_references.into_iter()) +} + +/// Validate that the [`TaggedReference`]s are in the correct order and are a complete set. +fn validate_checks( + checked_references: impl IntoIterator>, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut current_index = 0; + let mut is_complete = false; + let mut validated_refs: Vec> = vec![]; + for reference in checked_references { + if reference.get_index() != current_index { + return Err(ImplementationError::from( + "Missing reference check at index {current_index}", + )); + } + if reference.get_index() == reference.final_index { + is_complete = true + } else { + current_index += 1; + } + validated_refs.push(reference); + } + if !is_complete { + return Err(ImplementationError::from("Missing reference check at index {current_index}")); + } + Ok(validated_refs.into_iter()) +} + /// Validate the payload of a Payjoin request for PSBT and Params sanity pub(crate) fn parse_payload( base64: &str, @@ -254,7 +391,7 @@ pub struct PsbtContext { impl PsbtContext { /// Prepare the PSBT by creating a new PSBT and copying only the fields allowed by the [spec](https://github.com/bitcoin/bips/blob/master/bip-0078.mediawiki#senders-payjoin-proposal-checklist) - fn prepare_psbt(self, processed_psbt: Psbt) -> Psbt { + fn prepare_psbt(&self, processed_psbt: Psbt) -> Psbt { tracing::trace!("Original PSBT from callback: {processed_psbt:#?}"); // Create a new PSBT and copy only the allowed fields @@ -339,6 +476,15 @@ impl PsbtContext { ) -> Result { let psbt = self.psbt_to_sign(); let signed_psbt = wallet_process_psbt(&psbt)?; + self.finalize_signed_proposal(signed_psbt) + } + + /// Finalizes the signed payjoin proposal PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// Returns a final payjoin proposal PSBT after verifying the signed PSBT matches the payjoin + /// proposal PSBT and sanitizing it. + fn finalize_signed_proposal(&self, signed_psbt: Psbt) -> Result { let expected_ntxid = self.payjoin_psbt.unsigned_tx.compute_ntxid(); let actual_ntxid = signed_psbt.unsigned_tx.compute_ntxid(); if expected_ntxid != actual_ntxid { @@ -370,6 +516,17 @@ impl OriginalPayload { &self, min_fee_rate: Option, can_broadcast: impl Fn(&bitcoin::Transaction) -> Result, + ) -> Result<(), Error> { + self.apply_broadcast_suitability( + min_fee_rate, + can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate())?, + ) + } + + pub fn apply_broadcast_suitability( + &self, + min_fee_rate: Option, + can_broadcast: bool, ) -> Result<(), Error> { let original_psbt_fee_rate = self.psbt_fee_rate()?; if let Some(min_fee_rate) = min_fee_rate { @@ -381,80 +538,144 @@ impl OriginalPayload { .into()); } } - if can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate()) - .map_err(Error::Implementation)? - { + if can_broadcast { Ok(()) } else { Err(InternalPayloadError::OriginalPsbtNotBroadcastable.into()) } } - /// Check that the original PSBT has no receiver-owned inputs. + /// Check that the original PSBT has no receiver owned inputs. /// /// An attacker can try to spend the receiver's own inputs. This check prevents that. pub fn check_inputs_not_owned( &self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result<(), Error> { - let mut err: Result<(), Error> = Ok(()); - if let Some(e) = self + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + let final_index = self.psbt.input_pairs().count() - 1; + let script_references = self .psbt .input_pairs() - .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => Some(txout.script_pubkey.to_owned()), - Err(e) => { - **err = Err(InternalPayloadError::PrevTxOut(e).into()); - None - } - }) - .find_map(|script| match is_owned(&script) { - Ok(false) => None, - Ok(true) => Some(InternalPayloadError::InputOwned(script).into()), - Err(e) => Some(Error::Implementation(e)), + .enumerate() + .map(|(index, input)| match input.previous_txout() { + Ok(txout) => Ok(Reference::::new( + txout.script_pubkey.to_owned(), + index, + final_index, + )), + Err(e) => Err(InternalPayloadError::PrevTxOut(e)), }) - { - return Err(e); + .collect::>, InternalPayloadError>>()?; + Ok(script_references.into_iter()) + } + + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_scripts)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => + Err(InternalPayloadError::InputOwned(checked_input.get_value()).into()), + None => Ok(()), } - err?; - Ok(()) } pub fn check_no_inputs_seen_before( &self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result<(), Error> { - self.psbt.input_pairs().try_for_each(|input| { - match is_known(&input.txin.previous_output) { - Ok(false) => Ok::<(), Error>(()), - Ok(true) => { - tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", input.txin.previous_output); - Err(InternalPayloadError::InputSeen(input.txin.previous_output))? - }, - Err(e) => Err(Error::Implementation(e))?, + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.input_pairs().count() - 1; + let outpoint_references = self + .psbt + .input_pairs() + .enumerate() + .map(|(index, input)| { + Reference::::new( + input.txin.previous_output, + index, + final_index, + ) + }) + .collect::>(); + outpoint_references.into_iter() + } + + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_outpoints)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => { + tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", checked_input.get_value()); + Err(InternalPayloadError::InputSeen(checked_input.get_value()))? } - })?; - Ok(()) + None => Ok(()), + } } pub fn identify_receiver_outputs( self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - let owned_vouts: Vec = self + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.unsigned_tx.output.len() - 1; + let script_references = self .psbt .unsigned_tx .output .iter() .enumerate() - .filter_map(|(vout, txo)| match is_receiver_output(&txo.script_pubkey) { - Ok(true) => Some(Ok(vout)), - Ok(false) => None, - Err(e) => Some(Err(e)), + .map(|(index, output)| { + Reference::::new( + output.script_pubkey.clone(), + index, + final_index, + ) }) - .collect::, _>>() - .map_err(Error::Implementation)?; + .collect::>(); + script_references.into_iter() + } + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + let validated_checks = validate_checks(checked_output_scripts)?; + let owned_vouts = validated_checks + .into_iter() + .filter_map(|checked_output| match checked_output.get_result() { + true => Some(checked_output.get_index()), + false => None, + }) + .collect::>(); if owned_vouts.is_empty() { return Err(InternalPayloadError::MissingPayment.into()); } @@ -1022,29 +1243,20 @@ pub(crate) mod tests { #[test] fn test_finalize_proposal() { + // Outcome 1: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError let psbt_context = psbt_context_from_test_vector(); - - // Outcome 1: wallet_process_psbt returns an implementation error → ImplementationError let err = psbt_context .clone() - .finalize_proposal(|_| Err(ImplementationError::from("wallet signing failed"))) - .expect_err("Should fail when wallet_process_psbt returns an error"); - assert_eq!(err.to_string(), "wallet signing failed"); - - // Outcome 2: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError - let psbt_context = psbt_context_from_test_vector(); - let err = psbt_context - .clone() - .finalize_proposal(|_| { + .finalize_signed_proposal( // return a totally different psbt to trigger ntxid mismatch - Ok(PARSED_ORIGINAL_PSBT.clone()) - }) + PARSED_ORIGINAL_PSBT.clone(), + ) .expect_err("Should fail when ntxid mismatches"); assert!(err.to_string().contains("Ntxid mismatch")); - // Outcome 3: wallet_process_psbt succeeds → Ok(Psbt) + // Outcome 2: wallet_process_psbt succeeds → Ok(Psbt) let _psbt = psbt_context - .finalize_proposal(|_| Ok(PARSED_PAYJOIN_PROPOSAL.clone())) + .finalize_signed_proposal(PARSED_PAYJOIN_PROPOSAL.clone()) .expect("Should succeed when wallet_process_psbt returns a valid signed psbt"); } } diff --git a/payjoin/src/core/receive/v1/mod.rs b/payjoin/src/core/receive/v1/mod.rs index a43ec727e..e102593e9 100644 --- a/payjoin/src/core/receive/v1/mod.rs +++ b/payjoin/src/core/receive/v1/mod.rs @@ -113,6 +113,42 @@ impl UncheckedOriginalPayload { Ok(MaybeInputsOwned { original: self.original }) } + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Self::apply_broadcast_suitability`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Self::extract_tx_to_check_broadcast_suitability`] first to acquire the tx + /// to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn apply_broadcast_suitability( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> Result { + self.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable)?; + Ok(MaybeInputsOwned { original: self.original }) + } + /// Moves on to the next typestate without any of the current typestate's validations. /// /// Use this for interactive payment receivers, where there is no risk of a probing attack since the @@ -151,7 +187,35 @@ impl MaybeInputsOwned { self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.check_inputs_not_owned(is_owned)?; + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Self::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_input_owned_checks(checked_input_scripts)?; Ok(MaybeInputsSeen { original: self.original }) } } @@ -176,7 +240,42 @@ impl MaybeInputsSeen { self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result { - self.original.check_no_inputs_seen_before(is_known)?; + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Self::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result { + self.original.apply_input_seen_checks(checked_input_outpoints)?; Ok(OutputsUnknown { original: self.original }) } } @@ -208,7 +307,49 @@ impl OutputsUnknown { self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.identify_receiver_outputs(is_receiver_output) + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Self::get_output_script_refs`] to obtain the references that need to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_output_owned_checks(checked_output_scripts) } } @@ -290,11 +431,9 @@ impl ProvisionalProposal { self, wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> Result { - let finalized_psbt = self - .psbt_context - .finalize_proposal(wallet_process_psbt) - .map_err(|e| Error::Implementation(ImplementationError::new(e)))?; - Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt)?; + self.finalize_signed_proposal(&signed_psbt) } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -303,6 +442,17 @@ impl ProvisionalProposal { /// is different from the entity that has access to the private keys, /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.psbt_context.psbt_to_sign() } + + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Self::psbt_to_sign`] to obtain the payjoin proposal's unsigned + /// PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal(self, signed_psbt: &Psbt) -> Result { + let finalized_psbt = self.psbt_context.finalize_signed_proposal(signed_psbt.clone())?; + Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + } } /// A finalized Payjoin proposal, complete with fees and receiver signatures, that the sender diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 2c3f1bd5b..fb3e0d12a 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -30,7 +30,7 @@ use std::time::Duration; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; -use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, TxOut, Txid}; +use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, ScriptBuf, Transaction, TxOut, Txid}; pub(crate) use error::InternalSessionError; pub use error::SessionError; use serde::de::Deserializer; @@ -57,7 +57,10 @@ use crate::persist::{ MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults, MaybeSuccessTransition, MaybeTransientTransition, NextStateTransition, TerminalTransition, }; -use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext}; +use crate::receive::{ + check_references, parse_payload, InputOwnedTag, InputPair, InputSeenTag, OriginalPayload, + OutputOwnedTag, PsbtContext, Reference, TaggedReference, +}; use crate::time::Time; use crate::uri::ShortId; use crate::{ImplementationError, IntoUrl, IntoUrlError, Request, Version}; @@ -640,7 +643,68 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_broadcast_suitability(min_fee_rate, can_broadcast) { + let tx = self.extract_tx_to_check_broadcast_suitability(); + match can_broadcast(&tx) { + Ok(is_broadcast_suitable) => + self.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Moves on to the next typestate without any of the current typestate's validations. + /// + /// Use this for interactive payment receivers, where there is no risk of a probing attack since the + /// receiver needs to manually create payjoin URIs. + pub fn assume_interactive_receiver( + self, + ) -> NextStateTransition> { + NextStateTransition::success( + SessionEvent::CheckedBroadcastSuitability(), + Receiver { + state: MaybeInputsOwned { original: self.original.clone() }, + session_context: self.session_context, + }, + ) + } + + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Receiver::apply_broadcast_suitability`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Receiver::extract_tx_to_check_broadcast_suitability`] first to + /// acquire the tx to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn apply_broadcast_suitability( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable) { Ok(()) => MaybeFatalTransition::success( SessionEvent::CheckedBroadcastSuitability(), Receiver { @@ -661,22 +725,6 @@ impl Receiver { } } - /// Moves on to the next typestate without any of the current typestate's validations. - /// - /// Use this for interactive payment receivers, where there is no risk of a probing attack since the - /// receiver needs to manually create payjoin URIs. - pub fn assume_interactive_receiver( - self, - ) -> NextStateTransition> { - NextStateTransition::success( - SessionEvent::CheckedBroadcastSuitability(), - Receiver { - state: MaybeInputsOwned { original: self.original.clone() }, - session_context: self.session_context, - }, - ) - } - pub(crate) fn apply_checked_broadcast_suitability(self) -> ReceiveSession { let new_state = Receiver { state: MaybeInputsOwned { original: self.original.clone() }, @@ -720,7 +768,55 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_inputs_not_owned(is_owned) { + match self.get_input_script_refs() { + Ok(input_scripts) => match check_references(input_scripts, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + }) { + Ok(checked_input_scripts) => self.apply_input_owned_checks(checked_input_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + }, + Err(e) => match e { + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), + }, + } + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.state.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Receiver::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_owned_checks(checked_input_scripts) { Ok(()) => MaybeFatalTransition::success( SessionEvent::CheckedInputsNotOwned(), Receiver { @@ -777,7 +873,50 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_no_inputs_seen_before(is_known) { + match check_references(self.get_input_outpoint_refs(), is_known) { + Ok(checked_input_outpoints) => self.apply_input_seen_checks(checked_input_outpoints), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Receiver::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_seen_checks(checked_input_outpoints) { Ok(()) => MaybeFatalTransition::success( SessionEvent::CheckedNoInputsSeenBefore(), Receiver { @@ -839,7 +978,55 @@ impl Receiver { Error, Receiver, > { - match self.state.original.identify_receiver_outputs(is_receiver_output) { + match check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + }) { + Ok(checked_output_scripts) => self.apply_output_owned_checks(checked_output_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Receiver::get_output_script_refs`] to obtain the references that need + /// to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn apply_output_owned_checks( + self, + checked_output_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_output_owned_checks(checked_output_scripts) { Ok(inner) => MaybeFatalTransition::success( SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, @@ -1084,19 +1271,12 @@ impl Receiver { wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> MaybeTransientTransition, ImplementationError> { - let original_psbt = self.state.psbt_context.original_psbt.clone(); - let payjoin_psbt = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { - Ok(payjoin_psbt) => payjoin_psbt, - Err(e) => { - return MaybeTransientTransition::transient(e); - } - }; - let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; - let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; - MaybeTransientTransition::success( - SessionEvent::FinalizedProposal(payjoin_psbt), - Receiver { state: payjoin_proposal, session_context: self.session_context }, - ) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt); + match signed_psbt { + Ok(signed_psbt) => self.finalize_signed_proposal(&signed_psbt), + Err(e) => MaybeTransientTransition::transient(e), + } } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -1106,6 +1286,33 @@ impl Receiver { /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.state.psbt_context.psbt_to_sign() } + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Receiver::psbt_to_sign`] to obtain the payjoin + /// proposal's unsigned PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal( + self, + signed_psbt: &Psbt, + ) -> MaybeTransientTransition, ImplementationError> + { + let original_psbt = self.state.psbt_context.original_psbt.clone(); + let payjoin_psbt = + match self.state.psbt_context.finalize_signed_proposal(signed_psbt.clone()) { + Ok(payjoin_psbt) => payjoin_psbt, + Err(e) => { + return MaybeTransientTransition::transient(e); + } + }; + let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; + let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; + MaybeTransientTransition::success( + SessionEvent::FinalizedProposal(payjoin_psbt), + Receiver { state: payjoin_proposal, session_context: self.session_context }, + ) + } + pub(crate) fn apply_payjoin_proposal(self, payjoin_psbt: Psbt) -> ReceiveSession { let psbt_context = PsbtContext { payjoin_psbt, @@ -1311,66 +1518,95 @@ impl Receiver { &self, transaction_exists: impl Fn(Txid) -> Result, ImplementationError>, ) -> MaybeFatalOrSuccessTransition { - let fallback_tx = self - .state - .psbt_context - .original_psbt - .clone() - .extract_tx_fee_rate_limit() - .expect("fallback transaction should be in the receiver context"); - - // If the fallback transaction included any non-SegWit inputs, then the transaction ID of - // the Payjoin proposal is going to change when the sender signs their non-SegWit address - // one more time. The receiver cannot monitor the transaction, and should conclude the session. - if fallback_tx.input.iter().any(|txin| txin.witness.is_empty()) { - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::PayjoinProposalSent, - )); + if let transition @ MaybeFatalOrSuccessTransition::Success(_) = + self.check_fallback_monitorable() + { + return transition; } - let payjoin_proposal = &self.state.psbt_context.payjoin_psbt; - let payjoin_txid = payjoin_proposal.unsigned_tx.compute_txid(); // If the sender is spending SegWit-only inputs, then the transaction ID of the Payjoin proposal // is not going to change when the sender signs it. So we can use the TXID to check the // network for the Payjoin proposal. - match transaction_exists(payjoin_txid) { - Ok(Some(tx)) => { - let tx_id = tx.compute_txid(); - if tx_id != payjoin_txid { - return MaybeFatalOrSuccessTransition::transient(Error::Implementation( - ImplementationError::from(format!("Payjoin transaction ID mismatch. Expected: {payjoin_txid}, Got: {tx_id}").as_str()), - )); - } - // TODO: should we check for witness and scriptsig on the tx? - let mut sender_witnesses = vec![]; - - for i in self.state.psbt_context.sender_input_indexes() { - let input = - tx.input.get(i).expect("sender_input_indexes should return valid indices"); - sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); - } - // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::Success(sender_witnesses), - )); - } + match transaction_exists(self.extract_payjoin_proposal_txid()) { + Ok(Some(tx)) => return self.payjoin_tx_exists(tx), Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), } // If the Payjoin proposal was not found, check the fallback transaction, as it is // the second of two transactions whose IDs the receiver is aware of. - match transaction_exists(fallback_tx.compute_txid()) { - Ok(Some(_)) => - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::FallbackBroadcasted, - )), + match transaction_exists(self.extract_fallback_txid()) { + Ok(Some(_)) => return self.fallback_tx_exists(), Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), } MaybeFatalOrSuccessTransition::no_results(self.clone()) } + + pub fn extract_fallback_tx(&self) -> Transaction { + self.state.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate() + } + + pub fn extract_fallback_txid(&self) -> Txid { self.extract_fallback_tx().compute_txid() } + + pub fn extract_payjoin_proposal_txid(&self) -> Txid { + self.state.psbt_context.payjoin_psbt.clone().extract_tx_unchecked_fee_rate().compute_txid() + } + + pub fn check_fallback_monitorable( + &self, + ) -> MaybeFatalOrSuccessTransition { + // If the fallback transaction included any non-SegWit inputs, then the transaction ID of + // the Payjoin proposal is going to change when the sender signs their non-SegWit address + // one more time. The receiver cannot monitor the transaction, and should conclude the session. + if has_empty_witness(&self.extract_fallback_tx()) { + return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::PayjoinProposalSent, + )); + } + + MaybeFatalOrSuccessTransition::no_results(self.clone()) + } + + pub fn fallback_tx_exists(&self) -> MaybeFatalOrSuccessTransition { + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::FallbackBroadcasted, + )) + } + + pub fn payjoin_tx_exists( + &self, + tx: Transaction, + ) -> MaybeFatalOrSuccessTransition { + // TODO: should we check for witness and scriptsig on the tx? + let payjoin_txid = self.state.psbt_context.payjoin_psbt.unsigned_tx.compute_txid(); + let tx_id = tx.compute_txid(); + if tx_id != payjoin_txid { + return MaybeFatalOrSuccessTransition::transient(Error::Implementation( + ImplementationError::from( + format!( + "Payjoin transaction ID mismatch. Expected: {payjoin_txid}, Got: {tx_id}" + ) + .as_str(), + ), + )); + } + let mut sender_witnesses = vec![]; + + for i in self.state.psbt_context.sender_input_indexes() { + let input = tx.input.get(i).expect("sender_input_indexes should return valid indices"); + sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); + } + // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Success( + sender_witnesses, + ))) + } +} + +fn has_empty_witness(tx: &Transaction) -> bool { + tx.input.iter().any(|txin| txin.witness.is_empty()) } /// Derive a mailbox endpoint on a directory given a [`ShortId`]. From 068252b27e4fcf21bd15cefb666b507d1d24255b Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Wed, 13 May 2026 19:32:26 -0500 Subject: [PATCH 5/8] Add FFI bindings for non-blocking receive interface Expose the two-phase validation API from the previous commit through the FFI bindings layer. Update integration tests in C#, Dart, JavaScript, and Python to exercise both callback and nonblocking transition modes. --- payjoin-ffi/csharp/IntegrationTests.cs | 138 ++++- .../test/test_payjoin_integration_test.dart | 526 +++++++++++------- .../javascript/test/integration.test.ts | 128 ++++- .../test/test_payjoin_integration_test.py | 175 ++++-- payjoin-ffi/src/receive/mod.rs | 227 +++++++- 5 files changed, 898 insertions(+), 296 deletions(-) diff --git a/payjoin-ffi/csharp/IntegrationTests.cs b/payjoin-ffi/csharp/IntegrationTests.cs index 032686f65..10951837c 100644 --- a/payjoin-ffi/csharp/IntegrationTests.cs +++ b/payjoin-ffi/csharp/IntegrationTests.cs @@ -5,6 +5,12 @@ namespace Payjoin.Tests { + public enum TransitionMode + { + Callback, + Nonblocking, + } + public class IntegrationTests : IAsyncLifetime { private static string RpcCall(RpcClient rpc, string method, params string?[] args) => rpc.Call(method, args); @@ -168,6 +174,7 @@ private static InputPair[] GetInputs(RpcClient rpc) RpcClient receiverRpc, InMemoryReceiverPersister recvPersister, string ohttpRelay, + TransitionMode mode, CancellationToken cancellationToken) { var request = receiver.CreatePollRequest(ohttpRelay); @@ -192,7 +199,7 @@ private static InputPair[] GetInputs(RpcClient rpc) if (outcome is InitializedTransitionOutcome.Progress progress) { using var proposal = progress.inner; - return await ProcessUncheckedProposal(proposal, receiverRpc, recvPersister); + return await ProcessUncheckedProposal(proposal, receiverRpc, recvPersister, mode); } throw new InvalidOperationException("Unknown initialized transition outcome"); @@ -201,88 +208,157 @@ private static InputPair[] GetInputs(RpcClient rpc) private Task ProcessUncheckedProposal( UncheckedOriginalPayload proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var checkedTransition = proposal.CheckBroadcastSuitability(null, new MempoolAcceptanceCallback(receiverRpc)); - using var maybeInputsOwned = checkedTransition.Save(recvPersister); + MaybeInputsOwned maybeInputsOwned; + + if (mode == TransitionMode.Callback) + { + using var checkedTransition = proposal.CheckBroadcastSuitability(null, new MempoolAcceptanceCallback(receiverRpc)); + maybeInputsOwned = checkedTransition.Save(recvPersister); + } + else + { + var canBroadcast = new MempoolAcceptanceCallback(receiverRpc).Callback(proposal.ExtractTxToCheckBroadcastSuitability()); + using var checkedTransition = proposal.ApplyBroadcastSuitability(null, canBroadcast); + maybeInputsOwned = checkedTransition.Save(recvPersister); + } - return ProcessMaybeInputsOwned(maybeInputsOwned, receiverRpc, recvPersister); + return ProcessMaybeInputsOwned(maybeInputsOwned, receiverRpc, recvPersister, mode); } private Task ProcessMaybeInputsOwned( MaybeInputsOwned proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.CheckInputsNotOwned(new IsScriptOwnedCallback(receiverRpc)); - using var maybeInputsSeen = transition.Save(recvPersister); + MaybeInputsSeen maybeInputsSeen; - return ProcessMaybeInputsSeen(maybeInputsSeen, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.CheckInputsNotOwned(new IsScriptOwnedCallback(receiverRpc)); + maybeInputsSeen = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetInputScriptRefs() + .Select(r => r.Mark(new IsScriptOwnedCallback(receiverRpc).Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyInputOwnedChecks(taggedRefs); + maybeInputsSeen = transition.Save(recvPersister); + } + + return ProcessMaybeInputsSeen(maybeInputsSeen, receiverRpc, recvPersister, mode); } private Task ProcessMaybeInputsSeen( MaybeInputsSeen proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.CheckNoInputsSeenBefore(new CheckInputsNotSeenCallback()); - using var outputsUnknown = transition.Save(recvPersister); + OutputsUnknown outputsUnknown; - return ProcessOutputsUnknown(outputsUnknown, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.CheckNoInputsSeenBefore(new CheckInputsNotSeenCallback()); + outputsUnknown = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetInputOutpointRefs() + .Select(r => r.Mark(new CheckInputsNotSeenCallback().Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyInputSeenChecks(taggedRefs); + outputsUnknown = transition.Save(recvPersister); + } + + return ProcessOutputsUnknown(outputsUnknown, receiverRpc, recvPersister, mode); } private Task ProcessOutputsUnknown( OutputsUnknown proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.IdentifyReceiverOutputs(new IsScriptOwnedCallback(receiverRpc)); - using var wantsOutputs = transition.Save(recvPersister); + WantsOutputs wantsOutputs; - return ProcessWantsOutputs(wantsOutputs, receiverRpc, recvPersister); + if (mode == TransitionMode.Callback) + { + using var transition = proposal.IdentifyReceiverOutputs(new IsScriptOwnedCallback(receiverRpc)); + wantsOutputs = transition.Save(recvPersister); + } + else + { + var taggedRefs = proposal.GetOutputScriptRefs() + .Select(r => r.Mark(new IsScriptOwnedCallback(receiverRpc).Callback(r.GetValue()))) + .ToArray(); + using var transition = proposal.ApplyOutputOwnedChecks(taggedRefs); + wantsOutputs = transition.Save(recvPersister); + } + + return ProcessWantsOutputs(wantsOutputs, receiverRpc, recvPersister, mode); } private Task ProcessWantsOutputs( WantsOutputs proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var transition = proposal.CommitOutputs(); using var wantsInputs = transition.Save(recvPersister); - return ProcessWantsInputs(wantsInputs, receiverRpc, recvPersister); + return ProcessWantsInputs(wantsInputs, receiverRpc, recvPersister, mode); } private Task ProcessWantsInputs( WantsInputs proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var contributed = proposal.ContributeInputs(GetInputs(receiverRpc)); using var transition = contributed.CommitInputs(); using var wantsFeeRange = transition.Save(recvPersister); - return ProcessWantsFeeRange(wantsFeeRange, receiverRpc, recvPersister); + return ProcessWantsFeeRange(wantsFeeRange, receiverRpc, recvPersister, mode); } private Task ProcessWantsFeeRange( WantsFeeRange proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { using var transition = proposal.ApplyFeeRange(1, 10); using var provisional = transition.Save(recvPersister); - return ProcessProvisionalProposal(provisional, receiverRpc, recvPersister); + return ProcessProvisionalProposal(provisional, receiverRpc, recvPersister, mode); } private Task ProcessProvisionalProposal( ProvisionalProposal proposal, RpcClient receiverRpc, - InMemoryReceiverPersister recvPersister) + InMemoryReceiverPersister recvPersister, + TransitionMode mode) { - using var transition = proposal.FinalizeProposal(new ProcessPsbtCallback(receiverRpc)); - var payjoinProposal = transition.Save(recvPersister); + PayjoinProposal payjoinProposal; + + if (mode == TransitionMode.Callback) + { + using var transition = proposal.FinalizeProposal(new ProcessPsbtCallback(receiverRpc)); + payjoinProposal = transition.Save(recvPersister); + } + else + { + var signedPsbt = new ProcessPsbtCallback(receiverRpc).Callback(proposal.PsbtToSign()); + using var transition = proposal.FinalizeSignedProposal(signedPsbt); + payjoinProposal = transition.Save(recvPersister); + } return Task.FromResult(payjoinProposal); } @@ -438,8 +514,10 @@ public void TestFfiValidation() }); } - [Fact] - public async Task TestIntegrationV2ToV2() + [Theory] + [InlineData(TransitionMode.Callback)] + [InlineData(TransitionMode.Nonblocking)] + public async Task TestIntegrationV2ToV2(TransitionMode mode) { var cancellationToken = TestContext.Current.CancellationToken; @@ -465,7 +543,7 @@ public async Task TestIntegrationV2ToV2() using var receiveTransition = receiverBuilder.Build(); using var session = receiveTransition.Save(recvPersister); - var initial = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, cancellationToken); + var initial = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, mode, cancellationToken); Assert.Null(initial); // ***************************** @@ -500,7 +578,7 @@ public async Task TestIntegrationV2ToV2() // ********************* // RECEIVER SIDE // Poll for the proposal - using var payjoinProposal = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, cancellationToken); + using var payjoinProposal = await RetrieveReceiverProposal(session, receiver, recvPersister, ohttpRelay, mode, cancellationToken); Assert.NotNull(payjoinProposal); Assert.IsType(payjoinProposal); diff --git a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart index b40d40ea9..df6ee4aef 100644 --- a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart +++ b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart @@ -15,6 +15,8 @@ late test_utils.BitcoindInstance bitcoind; late test_utils.RpcClient receiver; late test_utils.RpcClient sender; +enum TransitionMode { callback, nonblocking } + class MempoolAcceptanceCallback implements payjoin.CanBroadcast { final payjoin.RpcClient connection; @@ -202,91 +204,186 @@ List get_inputs(payjoin.RpcClient rpc_connection) { Future process_provisional_proposal( payjoin.ProvisionalProposal proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final payjoin_proposal = proposal - .finalizeProposal(processPsbt: ProcessPsbtCallback(receiver)) - .save(persister: recv_persister); + final payjoin.PayjoinProposal payjoin_proposal; + if (mode == TransitionMode.callback) { + payjoin_proposal = proposal + .finalizeProposal(processPsbt: ProcessPsbtCallback(receiver)) + .save(persister: recv_persister); + } else { + final signed_psbt = ProcessPsbtCallback( + receiver, + ).callback(proposal.psbtToSign()); + payjoin_proposal = proposal + .finalizeSignedProposal(signedPsbt: signed_psbt) + .save(persister: recv_persister); + } return payjoin.PayjoinProposalReceiveSession(payjoin_proposal); } Future process_wants_fee_range( payjoin.WantsFeeRange proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final wants_fee_range = proposal .applyFeeRange(minFeeRateSatPerVb: 1, maxEffectiveFeeRateSatPerVb: 10) .save(persister: recv_persister); - return await process_provisional_proposal(wants_fee_range, recv_persister); + return await process_provisional_proposal( + wants_fee_range, + recv_persister, + mode, + ); } Future process_wants_inputs( payjoin.WantsInputs proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final provisional_proposal = proposal .contributeInputs(replacementInputs: get_inputs(receiver)) .commitInputs() .save(persister: recv_persister); - return await process_wants_fee_range(provisional_proposal, recv_persister); + return await process_wants_fee_range( + provisional_proposal, + recv_persister, + mode, + ); } Future process_wants_outputs( payjoin.WantsOutputs proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { final wants_inputs = proposal.commitOutputs().save(persister: recv_persister); - return await process_wants_inputs(wants_inputs, recv_persister); + return await process_wants_inputs(wants_inputs, recv_persister, mode); } Future process_outputs_unknown( payjoin.OutputsUnknown proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final wants_outputs = proposal - .identifyReceiverOutputs( - isReceiverOutput: IsScriptOwnedCallback(receiver), - ) - .save(persister: recv_persister); - return await process_wants_outputs(wants_outputs, recv_persister); + final payjoin.WantsOutputs wants_outputs; + if (mode == TransitionMode.callback) { + wants_outputs = proposal + .identifyReceiverOutputs( + isReceiverOutput: IsScriptOwnedCallback(receiver), + ) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getOutputScriptRefs() + .map( + (ref) => ref.mark( + result: IsScriptOwnedCallback(receiver).callback(ref.getValue()), + ), + ) + .toList(); + wants_outputs = proposal + .applyOutputOwnedChecks(checkedOutputScripts: tagged_refs) + .save(persister: recv_persister); + } + return await process_wants_outputs(wants_outputs, recv_persister, mode); } Future process_maybe_inputs_seen( payjoin.MaybeInputsSeen proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final outputs_unknown = proposal - .checkNoInputsSeenBefore(isKnown: CheckInputsNotSeenCallback(receiver)) - .save(persister: recv_persister); - return await process_outputs_unknown(outputs_unknown, recv_persister); + final payjoin.OutputsUnknown outputs_unknown; + if (mode == TransitionMode.callback) { + outputs_unknown = proposal + .checkNoInputsSeenBefore(isKnown: CheckInputsNotSeenCallback(receiver)) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getInputOutpointRefs() + .map( + (ref) => ref.mark( + result: CheckInputsNotSeenCallback( + receiver, + ).callback(ref.getValue()), + ), + ) + .toList(); + outputs_unknown = proposal + .applyInputSeenChecks(checkedInputOutpoints: tagged_refs) + .save(persister: recv_persister); + } + return await process_outputs_unknown(outputs_unknown, recv_persister, mode); } Future process_maybe_inputs_owned( payjoin.MaybeInputsOwned proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final maybe_inputs_owned = proposal - .checkInputsNotOwned(isOwned: IsScriptOwnedCallback(receiver)) - .save(persister: recv_persister); - return await process_maybe_inputs_seen(maybe_inputs_owned, recv_persister); + final payjoin.MaybeInputsSeen maybe_inputs_owned; + if (mode == TransitionMode.callback) { + maybe_inputs_owned = proposal + .checkInputsNotOwned(isOwned: IsScriptOwnedCallback(receiver)) + .save(persister: recv_persister); + } else { + final tagged_refs = proposal + .getInputScriptRefs() + .map( + (ref) => ref.mark( + result: IsScriptOwnedCallback(receiver).callback(ref.getValue()), + ), + ) + .toList(); + maybe_inputs_owned = proposal + .applyInputOwnedChecks(checkedInputScripts: tagged_refs) + .save(persister: recv_persister); + } + return await process_maybe_inputs_seen( + maybe_inputs_owned, + recv_persister, + mode, + ); } Future process_unchecked_proposal( payjoin.UncheckedOriginalPayload proposal, InMemoryReceiverPersister recv_persister, + TransitionMode mode, ) async { - final unchecked_proposal = proposal - .checkBroadcastSuitability( - minFeeRateSatPerKwu: null, - canBroadcast: MempoolAcceptanceCallback(receiver), - ) - .save(persister: recv_persister); - return await process_maybe_inputs_owned(unchecked_proposal, recv_persister); + final payjoin.MaybeInputsOwned unchecked_proposal; + if (mode == TransitionMode.callback) { + unchecked_proposal = proposal + .checkBroadcastSuitability( + minFeeRateSatPerKwu: null, + canBroadcast: MempoolAcceptanceCallback(receiver), + ) + .save(persister: recv_persister); + } else { + final can_broadcast = MempoolAcceptanceCallback( + receiver, + ).callback(proposal.extractTxToCheckBroadcastSuitability()); + unchecked_proposal = proposal + .applyBroadcastSuitability( + minFeeRateSatPerKwu: null, + canBroadcast: can_broadcast, + ) + .save(persister: recv_persister); + } + return await process_maybe_inputs_owned( + unchecked_proposal, + recv_persister, + mode, + ); } Future retrieve_receiver_proposal( payjoin.Initialized receiver, InMemoryReceiverPersister recv_persister, String ohttp_relay, + TransitionMode mode, ) async { var agent = http.Client(); var request = receiver.createPollRequest(ohttpRelay: ohttp_relay); @@ -303,7 +400,7 @@ Future retrieve_receiver_proposal( return null; } else if (res is payjoin.ProgressInitializedTransitionOutcome) { var proposal = res.inner; - return await process_unchecked_proposal(proposal, recv_persister); + return await process_unchecked_proposal(proposal, recv_persister, mode); } throw Exception("Unknown initialized transition outcome: $res"); @@ -313,12 +410,14 @@ Future process_receiver_proposal( payjoin.ReceiveSession receiver, InMemoryReceiverPersister recv_persister, String ohttp_relay, + TransitionMode mode, ) async { if (receiver is payjoin.InitializedReceiveSession) { var res = await retrieve_receiver_proposal( receiver.inner, recv_persister, ohttp_relay, + mode, ); if (res == null) { return null; @@ -327,25 +426,41 @@ Future process_receiver_proposal( } if (receiver is payjoin.UncheckedOriginalPayloadReceiveSession) { - return await process_unchecked_proposal(receiver.inner, recv_persister); + return await process_unchecked_proposal( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.MaybeInputsOwnedReceiveSession) { - return await process_maybe_inputs_owned(receiver.inner, recv_persister); + return await process_maybe_inputs_owned( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.MaybeInputsSeenReceiveSession) { - return await process_maybe_inputs_seen(receiver.inner, recv_persister); + return await process_maybe_inputs_seen( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.OutputsUnknownReceiveSession) { - return await process_outputs_unknown(receiver.inner, recv_persister); + return await process_outputs_unknown(receiver.inner, recv_persister, mode); } if (receiver is payjoin.WantsOutputsReceiveSession) { - return await process_wants_outputs(receiver.inner, recv_persister); + return await process_wants_outputs(receiver.inner, recv_persister, mode); } if (receiver is payjoin.WantsInputsReceiveSession) { - return await process_wants_inputs(receiver.inner, recv_persister); + return await process_wants_inputs(receiver.inner, recv_persister, mode); } if (receiver is payjoin.ProvisionalProposalReceiveSession) { - return await process_provisional_proposal(receiver.inner, recv_persister); + return await process_provisional_proposal( + receiver.inner, + recv_persister, + mode, + ); } if (receiver is payjoin.PayjoinProposalReceiveSession) { return receiver; @@ -354,6 +469,167 @@ Future process_receiver_proposal( throw Exception("Unknown receiver state: $receiver"); } +Future run_integration_v2_to_v2(TransitionMode mode) async { + env = test_utils.initBitcoindSenderReceiver(); + bitcoind = env.getBitcoind(); + receiver = env.getReceiver(); + sender = env.getSender(); + var receiver_address = + jsonDecode(receiver.call(method: "getnewaddress", params: [])) as String; + var services = test_utils.TestServices.initialize(); + + services.waitForServicesReady(); + var directory = services.directoryUrl(); + var ohttp_keys = services.fetchOhttpKeys(); + var ohttp_relay = services.ohttpRelayUrl(); + var agent = http.Client(); + + // ********************** + // Inside the Receiver: + var recv_persister = InMemoryReceiverPersister(); + var sender_persister = InMemorySenderPersister(); + var session = create_receiver_context( + receiver_address, + directory, + ohttp_keys, + recv_persister, + ); + var process_response = await process_receiver_proposal( + payjoin.InitializedReceiveSession(session), + recv_persister, + ohttp_relay, + mode, + ); + expect(process_response, isNull); + + // ********************** + // Inside the Sender: + // Create a funded PSBT (not broadcasted) to address with amount given in the pj_uri + var pj_uri = session.pjUri(); + var psbt = build_sweep_psbt(sender, pj_uri); + payjoin.WithReplyKey req_ctx = payjoin.SenderBuilder(psbt: psbt, uri: pj_uri) + .buildRecommended(minFeeRateSatPerKwu: 1000) + .save(persister: sender_persister); + payjoin.RequestOhttpContext request = req_ctx.createV2PostRequest( + ohttpRelay: ohttp_relay, + ); + var response = await agent.post( + Uri.parse(request.request.url), + headers: {"Content-Type": request.request.contentType}, + body: request.request.body, + ); + payjoin.PollingForProposal send_ctx = req_ctx + .processResponse(response: response.bodyBytes, postCtx: request.ohttpCtx) + .save(persister: sender_persister); + // POST Original PSBT + + // ********************** + // Inside the Receiver: + + // GET fallback psbt + payjoin.ReceiveSession? payjoin_proposal = await process_receiver_proposal( + payjoin.InitializedReceiveSession(session), + recv_persister, + ohttp_relay, + mode, + ); + expect(payjoin_proposal, isNotNull); + expect(payjoin_proposal, isA()); + + payjoin.PayjoinProposal proposal = + (payjoin_proposal as payjoin.PayjoinProposalReceiveSession).inner; + payjoin.RequestResponse request_response = proposal.createPostRequest( + ohttpRelay: ohttp_relay, + ); + var fallback_response = await agent.post( + Uri.parse(request_response.request.url), + headers: {"Content-Type": request_response.request.contentType}, + body: request_response.request.body, + ); + proposal.processResponse( + body: fallback_response.bodyBytes, + ohttpContext: request_response.clientResponse, + ); + + // ********************** + // Inside the Sender: + // Sender checks, signs, finalizes, extracts, and broadcasts + // Replay post fallback to get the response + payjoin.PollingForProposalTransitionOutcome? poll_outcome; + var attempts = 0; + while (true) { + payjoin.RequestOhttpContext ohttp_context_request = send_ctx + .createPollRequest(ohttpRelay: ohttp_relay); + var final_response = await agent.post( + Uri.parse(ohttp_context_request.request.url), + headers: {"Content-Type": ohttp_context_request.request.contentType}, + body: ohttp_context_request.request.body, + ); + poll_outcome = send_ctx + .processResponse( + response: final_response.bodyBytes, + ohttpCtx: ohttp_context_request.ohttpCtx, + ) + .save(persister: sender_persister); + + if (poll_outcome is payjoin.ProgressPollingForProposalTransitionOutcome) { + break; + } + + attempts += 1; + if (attempts >= 3) { + // Receiver not ready yet; mirror Python's tolerant polling. + return; + } + } + + final progressOutcome = + poll_outcome as payjoin.ProgressPollingForProposalTransitionOutcome; + var payjoin_psbt = jsonDecode( + sender.call( + method: "walletprocesspsbt", + params: [progressOutcome.psbtBase64], + ), + )["psbt"]; + var final_psbt = jsonDecode( + sender.call( + method: "finalizepsbt", + params: [payjoin_psbt, jsonEncode(false)], + ), + )["psbt"]; + var final_tx_hex = jsonDecode( + sender.call(method: "finalizepsbt", params: [final_psbt, jsonEncode(true)]), + )["hex"]; + sender.call(method: "sendrawtransaction", params: [jsonEncode(final_tx_hex)]); + + // Check resulting transaction and balances + var decodedTx = jsonDecode( + sender.call( + method: "decoderawtransaction", + params: [jsonEncode(final_tx_hex)], + ), + ); + var network_fees = + (jsonDecode( + sender.call( + method: "decodepsbt", + params: [jsonEncode(final_psbt)], + ), + )["fee"] + as num) + .toDouble(); + // Sender sent the entire value of their utxo to the receiver (minus fees) + expect(decodedTx["vin"].length, 2); + expect(decodedTx["vout"].length, 1); + expect( + jsonDecode( + receiver.call(method: "getbalances", params: []), + )["mine"]["untrusted_pending"], + 100 - network_fees, + ); + expect(jsonDecode(sender.call(method: "getbalance", params: [])), 0.0); +} + void main() { group('fetchOhttpKeys', () { test( @@ -471,176 +747,16 @@ void main() { ); }); - test('Test integration v2 to v2', () async { - env = test_utils.initBitcoindSenderReceiver(); - bitcoind = env.getBitcoind(); - receiver = env.getReceiver(); - sender = env.getSender(); - var receiver_address = - jsonDecode(receiver.call(method: "getnewaddress", params: [])) - as String; - var services = test_utils.TestServices.initialize(); - - services.waitForServicesReady(); - var directory = services.directoryUrl(); - var ohttp_keys = services.fetchOhttpKeys(); - var ohttp_relay = services.ohttpRelayUrl(); - var agent = http.Client(); - - // ********************** - // Inside the Receiver: - var recv_persister = InMemoryReceiverPersister(); - var sender_persister = InMemorySenderPersister(); - var session = create_receiver_context( - receiver_address, - directory, - ohttp_keys, - recv_persister, - ); - var process_response = await process_receiver_proposal( - payjoin.InitializedReceiveSession(session), - recv_persister, - ohttp_relay, - ); - expect(process_response, isNull); - - // ********************** - // Inside the Sender: - // Create a funded PSBT (not broadcasted) to address with amount given in the pj_uri - var pj_uri = session.pjUri(); - var psbt = build_sweep_psbt(sender, pj_uri); - payjoin.WithReplyKey req_ctx = - payjoin.SenderBuilder(psbt: psbt, uri: pj_uri) - .buildRecommended(minFeeRateSatPerKwu: 1000) - .save(persister: sender_persister); - payjoin.RequestOhttpContext request = req_ctx.createV2PostRequest( - ohttpRelay: ohttp_relay, - ); - var response = await agent.post( - Uri.parse(request.request.url), - headers: {"Content-Type": request.request.contentType}, - body: request.request.body, - ); - payjoin.PollingForProposal send_ctx = req_ctx - .processResponse( - response: response.bodyBytes, - postCtx: request.ohttpCtx, - ) - .save(persister: sender_persister); - // POST Original PSBT - - // ********************** - // Inside the Receiver: - - // GET fallback psbt - payjoin.ReceiveSession? payjoin_proposal = - await process_receiver_proposal( - payjoin.InitializedReceiveSession(session), - recv_persister, - ohttp_relay, - ); - expect(payjoin_proposal, isNotNull); - expect(payjoin_proposal, isA()); - - payjoin.PayjoinProposal proposal = - (payjoin_proposal as payjoin.PayjoinProposalReceiveSession).inner; - payjoin.RequestResponse request_response = proposal.createPostRequest( - ohttpRelay: ohttp_relay, - ); - var fallback_response = await agent.post( - Uri.parse(request_response.request.url), - headers: {"Content-Type": request_response.request.contentType}, - body: request_response.request.body, - ); - proposal.processResponse( - body: fallback_response.bodyBytes, - ohttpContext: request_response.clientResponse, - ); - - // ********************** - // Inside the Sender: - // Sender checks, signs, finalizes, extracts, and broadcasts - // Replay post fallback to get the response - payjoin.PollingForProposalTransitionOutcome? poll_outcome; - var attempts = 0; - while (true) { - payjoin.RequestOhttpContext ohttp_context_request = send_ctx - .createPollRequest(ohttpRelay: ohttp_relay); - var final_response = await agent.post( - Uri.parse(ohttp_context_request.request.url), - headers: {"Content-Type": ohttp_context_request.request.contentType}, - body: ohttp_context_request.request.body, - ); - poll_outcome = send_ctx - .processResponse( - response: final_response.bodyBytes, - ohttpCtx: ohttp_context_request.ohttpCtx, - ) - .save(persister: sender_persister); - - if (poll_outcome - is payjoin.ProgressPollingForProposalTransitionOutcome) { - break; - } - - attempts += 1; - if (attempts >= 3) { - // Receiver not ready yet; mirror Python's tolerant polling. - return; - } - } - - final progressOutcome = - poll_outcome as payjoin.ProgressPollingForProposalTransitionOutcome; - var payjoin_psbt = jsonDecode( - sender.call( - method: "walletprocesspsbt", - params: [progressOutcome.psbtBase64], - ), - )["psbt"]; - var final_psbt = jsonDecode( - sender.call( - method: "finalizepsbt", - params: [payjoin_psbt, jsonEncode(false)], - ), - )["psbt"]; - var final_tx_hex = jsonDecode( - sender.call( - method: "finalizepsbt", - params: [final_psbt, jsonEncode(true)], - ), - )["hex"]; - sender.call( - method: "sendrawtransaction", - params: [jsonEncode(final_tx_hex)], - ); + test( + 'Test integration v2 to v2 (callback)', + () async => run_integration_v2_to_v2(TransitionMode.callback), + timeout: const Timeout(Duration(minutes: 5)), + ); - // Check resulting transaction and balances - var decodedTx = jsonDecode( - sender.call( - method: "decoderawtransaction", - params: [jsonEncode(final_tx_hex)], - ), - ); - var network_fees = - (jsonDecode( - sender.call( - method: "decodepsbt", - params: [jsonEncode(final_psbt)], - ), - )["fee"] - as num) - .toDouble(); - // Sender sent the entire value of their utxo to the receiver (minus fees) - expect(decodedTx["vin"].length, 2); - expect(decodedTx["vout"].length, 1); - expect( - jsonDecode( - receiver.call(method: "getbalances", params: []), - )["mine"]["untrusted_pending"], - 100 - network_fees, - ); - expect(jsonDecode(sender.call(method: "getbalance", params: [])), 0.0); - }, timeout: const Timeout(Duration(minutes: 5))); + test( + 'Test integration v2 to v2 (nonblocking)', + () async => run_integration_v2_to_v2(TransitionMode.nonblocking), + timeout: const Timeout(Duration(minutes: 5)), + ); }); } diff --git a/payjoin-ffi/javascript/test/integration.test.ts b/payjoin-ffi/javascript/test/integration.test.ts index 0dde4a267..4ff072ddb 100644 --- a/payjoin-ffi/javascript/test/integration.test.ts +++ b/payjoin-ffi/javascript/test/integration.test.ts @@ -31,6 +31,8 @@ interface Utxo { scriptPubKey: string; } +type TransitionMode = "callback" | "nonblocking"; + type PayjoinModule = typeof nodejsPayjoin; const webPayjoin = webPayjoinModule as unknown as PayjoinModule; @@ -225,13 +227,23 @@ class ReceiverProcessor { private readonly payjoin: PayjoinModule, private readonly receiver: testUtils.RpcClient, private readonly recvPersister: InMemoryReceiverPersister, + private readonly mode: TransitionMode, ) {} private async processProvisionalProposal( proposal: PJ<"ProvisionalProposal">, ): Promise> { + if (this.mode === "callback") { + return proposal + .finalizeProposal(new ProcessPsbtCallback(this.receiver)) + .save(this.recvPersister) as PJ<"PayjoinProposal">; + } + + const signedPsbt = new ProcessPsbtCallback(this.receiver).callback( + proposal.psbtToSign(), + ); return proposal - .finalizeProposal(new ProcessPsbtCallback(this.receiver)) + .finalizeSignedProposal(signedPsbt) .save(this.recvPersister) as PJ<"PayjoinProposal">; } @@ -266,41 +278,109 @@ class ReceiverProcessor { private async processOutputsUnknown( proposal: PJ<"OutputsUnknown">, ): Promise> { - const wantsOutputs = proposal - .identifyReceiverOutputs(new IsScriptOwnedCallback(this.receiver)) - .save(this.recvPersister) as PJ<"WantsOutputs">; + let wantsOutputs: PJ<"WantsOutputs">; + + if (this.mode === "callback") { + wantsOutputs = proposal + .identifyReceiverOutputs( + new IsScriptOwnedCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"WantsOutputs">; + } else { + const taggedRefs = proposal + .getOutputScriptRefs() + .map((ref) => + ref.mark( + new IsScriptOwnedCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + wantsOutputs = proposal + .applyOutputOwnedChecks(taggedRefs) + .save(this.recvPersister) as PJ<"WantsOutputs">; + } + return this.processWantsOutputs(wantsOutputs); } private async processMaybeInputsSeen( proposal: PJ<"MaybeInputsSeen">, ): Promise> { - const outputsUnknown = proposal - .checkNoInputsSeenBefore( - new CheckInputsNotSeenCallback(this.receiver), - ) - .save(this.recvPersister) as PJ<"OutputsUnknown">; + let outputsUnknown: PJ<"OutputsUnknown">; + + if (this.mode === "callback") { + outputsUnknown = proposal + .checkNoInputsSeenBefore( + new CheckInputsNotSeenCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"OutputsUnknown">; + } else { + const taggedRefs = proposal + .getInputOutpointRefs() + .map((ref) => + ref.mark( + new CheckInputsNotSeenCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + outputsUnknown = proposal + .applyInputSeenChecks(taggedRefs) + .save(this.recvPersister) as PJ<"OutputsUnknown">; + } + return this.processOutputsUnknown(outputsUnknown); } private async processMaybeInputsOwned( proposal: nodejsPayjoin.MaybeInputsOwned, ): Promise> { - const maybeInputsSeen = proposal - .checkInputsNotOwned(new IsScriptOwnedCallback(this.receiver)) - .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + let maybeInputsSeen: PJ<"MaybeInputsSeen">; + + if (this.mode === "callback") { + maybeInputsSeen = proposal + .checkInputsNotOwned(new IsScriptOwnedCallback(this.receiver)) + .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + } else { + const taggedRefs = proposal + .getInputScriptRefs() + .map((ref) => + ref.mark( + new IsScriptOwnedCallback(this.receiver).callback( + ref.getValue(), + ), + ), + ); + maybeInputsSeen = proposal + .applyInputOwnedChecks(taggedRefs) + .save(this.recvPersister) as PJ<"MaybeInputsSeen">; + } + return this.processMaybeInputsSeen(maybeInputsSeen); } private async processUncheckedProposal( proposal: PJ<"UncheckedOriginalPayload">, ): Promise> { - const maybeInputsOwned = proposal - .checkBroadcastSuitability( - undefined, - new MempoolAcceptanceCallback(this.receiver), - ) - .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + let maybeInputsOwned: PJ<"MaybeInputsOwned">; + + if (this.mode === "callback") { + maybeInputsOwned = proposal + .checkBroadcastSuitability( + undefined, + new MempoolAcceptanceCallback(this.receiver), + ) + .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + } else { + const canBroadcastResult = new MempoolAcceptanceCallback( + this.receiver, + ).callback(proposal.extractTxToCheckBroadcastSuitability()); + maybeInputsOwned = proposal + .applyBroadcastSuitability(undefined, canBroadcastResult) + .save(this.recvPersister) as PJ<"MaybeInputsOwned">; + } + return this.processMaybeInputsOwned(maybeInputsOwned); } @@ -491,7 +571,10 @@ function testFfiValidation(payjoin: PayjoinModule): void { }, /AmountOutOfRange/); } -async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { +async function testIntegrationV2ToV2( + payjoin: PayjoinModule, + mode: TransitionMode, +): Promise { const env = testUtils.initBitcoindSenderReceiver(); const receiver = env.getReceiver(); const sender = env.getSender(); @@ -513,6 +596,7 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { payjoin, receiver, recvPersister, + mode, ); const senderPersister = new InMemorySenderPersister(); @@ -644,11 +728,13 @@ async function testIntegrationV2ToV2(payjoin: PayjoinModule): Promise { async function runTests(): Promise { await nodejsUniffiInitAsync(); testFfiValidation(nodejsPayjoin); - await testIntegrationV2ToV2(nodejsPayjoin); + await testIntegrationV2ToV2(nodejsPayjoin, "callback"); + await testIntegrationV2ToV2(nodejsPayjoin, "nonblocking"); await webUniffiInitAsync(); testFfiValidation(webPayjoin); - await testIntegrationV2ToV2(webPayjoin); + await testIntegrationV2ToV2(webPayjoin, "callback"); + await testIntegrationV2ToV2(webPayjoin, "nonblocking"); } runTests().catch((error: unknown) => { diff --git a/payjoin-ffi/python/test/test_payjoin_integration_test.py b/payjoin-ffi/python/test/test_payjoin_integration_test.py index 146fc41fe..ef518015a 100644 --- a/payjoin-ffi/python/test/test_payjoin_integration_test.py +++ b/payjoin-ffi/python/test/test_payjoin_integration_test.py @@ -2,7 +2,7 @@ import sys import httpx import json -from typing import cast, Protocol, Any +from typing import cast, Protocol, Any, Literal from payjoin import * from payjoin.http import fetch_ohttp_keys @@ -22,6 +22,9 @@ class HasInner(Protocol): inner: Any +TransitionMode = Literal["callback", "nonblocking"] + + class TestPayjoin(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): @@ -92,12 +95,14 @@ async def process_receiver_proposal( receiver: ReceiveSession, recv_persister: InMemoryReceiverPersister, ohttp_relay: str, + mode: TransitionMode, ) -> Optional[ReceiveSession.PAYJOIN_PROPOSAL]: if receiver.is_INITIALIZED(): res = await self.retrieve_receiver_proposal( cast(ReceiveSession.INITIALIZED, receiver).inner, recv_persister, ohttp_relay, + mode, ) if res is None: return None @@ -107,35 +112,49 @@ async def process_receiver_proposal( return await self.process_unchecked_proposal( cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, receiver).inner, recv_persister, + mode, ) if receiver.is_MAYBE_INPUTS_OWNED(): return await self.process_maybe_inputs_owned( - cast(ReceiveSession.MAYBE_INPUTS_OWNED, receiver).inner, recv_persister + cast(ReceiveSession.MAYBE_INPUTS_OWNED, receiver).inner, + recv_persister, + mode, ) if receiver.is_MAYBE_INPUTS_SEEN(): return await self.process_maybe_inputs_seen( - cast(ReceiveSession.MAYBE_INPUTS_SEEN, receiver).inner, recv_persister + cast(ReceiveSession.MAYBE_INPUTS_SEEN, receiver).inner, + recv_persister, + mode, ) if receiver.is_OUTPUTS_UNKNOWN(): return await self.process_outputs_unknown( - cast(ReceiveSession.OUTPUTS_UNKNOWN, receiver).inner, recv_persister + cast(ReceiveSession.OUTPUTS_UNKNOWN, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_OUTPUTS(): return await self.process_wants_outputs( - cast(ReceiveSession.WANTS_OUTPUTS, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_OUTPUTS, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_INPUTS(): return await self.process_wants_inputs( - cast(ReceiveSession.WANTS_INPUTS, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_INPUTS, receiver).inner, + recv_persister, + mode, ) if receiver.is_WANTS_FEE_RANGE(): return await self.process_wants_fee_range( - cast(ReceiveSession.WANTS_FEE_RANGE, receiver).inner, recv_persister + cast(ReceiveSession.WANTS_FEE_RANGE, receiver).inner, + recv_persister, + mode, ) if receiver.is_PROVISIONAL_PROPOSAL(): return await self.process_provisional_proposal( cast(ReceiveSession.PROVISIONAL_PROPOSAL, receiver).inner, recv_persister, + mode, ) if receiver.is_PAYJOIN_PROPOSAL(): return cast(ReceiveSession.PAYJOIN_PROPOSAL, receiver) @@ -161,6 +180,7 @@ async def retrieve_receiver_proposal( receiver: Initialized, recv_persister: InMemoryReceiverPersister, ohttp_relay: str, + mode: TransitionMode, ): agent = httpx.AsyncClient() request: RequestResponse = receiver.create_poll_request(ohttp_relay) @@ -175,80 +195,155 @@ async def retrieve_receiver_proposal( if res.is_STASIS(): return None return await self.process_unchecked_proposal( - cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, res).inner, recv_persister + cast(ReceiveSession.UNCHECKED_ORIGINAL_PAYLOAD, res).inner, + recv_persister, + mode, ) async def process_unchecked_proposal( self, proposal: UncheckedOriginalPayload, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - receiver = proposal.check_broadcast_suitability( - None, MempoolAcceptanceCallback(self.receiver) - ).save(recv_persister) - return await self.process_maybe_inputs_owned(receiver, recv_persister) + if mode == "callback": + receiver = proposal.check_broadcast_suitability( + None, MempoolAcceptanceCallback(self.receiver) + ).save(recv_persister) + else: + can_broadcast = MempoolAcceptanceCallback(self.receiver).callback( + proposal.extract_tx_to_check_broadcast_suitability() + ) + receiver = proposal.apply_broadcast_suitability(None, can_broadcast).save( + recv_persister + ) + return await self.process_maybe_inputs_owned(receiver, recv_persister, mode) async def process_maybe_inputs_owned( self, proposal: MaybeInputsOwned, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - maybe_inputs_owned = proposal.check_inputs_not_owned( - IsScriptOwnedCallback(self.receiver) - ).save(recv_persister) - return await self.process_maybe_inputs_seen(maybe_inputs_owned, recv_persister) + if mode == "callback": + maybe_inputs_owned = proposal.check_inputs_not_owned( + IsScriptOwnedCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark(IsScriptOwnedCallback(self.receiver).callback(ref.get_value())) + for ref in proposal.get_input_script_refs() + ] + maybe_inputs_owned = proposal.apply_input_owned_checks(tagged_refs).save( + recv_persister + ) + return await self.process_maybe_inputs_seen( + maybe_inputs_owned, recv_persister, mode + ) async def process_maybe_inputs_seen( - self, proposal: MaybeInputsSeen, recv_persister: InMemoryReceiverPersister + self, + proposal: MaybeInputsSeen, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - outputs_unknown = proposal.check_no_inputs_seen_before( - CheckInputsNotSeenCallback(self.receiver) - ).save(recv_persister) - return await self.process_outputs_unknown(outputs_unknown, recv_persister) + if mode == "callback": + outputs_unknown = proposal.check_no_inputs_seen_before( + CheckInputsNotSeenCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark( + CheckInputsNotSeenCallback(self.receiver).callback(ref.get_value()) + ) + for ref in proposal.get_input_outpoint_refs() + ] + outputs_unknown = proposal.apply_input_seen_checks(tagged_refs).save( + recv_persister + ) + return await self.process_outputs_unknown(outputs_unknown, recv_persister, mode) async def process_outputs_unknown( - self, proposal: OutputsUnknown, recv_persister: InMemoryReceiverPersister + self, + proposal: OutputsUnknown, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - wants_outputs = proposal.identify_receiver_outputs( - IsScriptOwnedCallback(self.receiver) - ).save(recv_persister) - return await self.process_wants_outputs(wants_outputs, recv_persister) + if mode == "callback": + wants_outputs = proposal.identify_receiver_outputs( + IsScriptOwnedCallback(self.receiver) + ).save(recv_persister) + else: + tagged_refs = [ + ref.mark(IsScriptOwnedCallback(self.receiver).callback(ref.get_value())) + for ref in proposal.get_output_script_refs() + ] + wants_outputs = proposal.apply_output_owned_checks(tagged_refs).save( + recv_persister + ) + return await self.process_wants_outputs(wants_outputs, recv_persister, mode) async def process_wants_outputs( - self, proposal: WantsOutputs, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsOutputs, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): wants_inputs = proposal.commit_outputs().save(recv_persister) - return await self.process_wants_inputs(wants_inputs, recv_persister) + return await self.process_wants_inputs(wants_inputs, recv_persister, mode) async def process_wants_inputs( - self, proposal: WantsInputs, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsInputs, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): provisional_proposal = ( proposal.contribute_inputs(get_inputs(self.receiver)) .commit_inputs() .save(recv_persister) ) - return await self.process_wants_fee_range(provisional_proposal, recv_persister) + return await self.process_wants_fee_range( + provisional_proposal, recv_persister, mode + ) async def process_wants_fee_range( - self, proposal: WantsFeeRange, recv_persister: InMemoryReceiverPersister + self, + proposal: WantsFeeRange, + recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): provisional_proposal = proposal.apply_fee_range(1, 10).save(recv_persister) return await self.process_provisional_proposal( - provisional_proposal, recv_persister + provisional_proposal, recv_persister, mode ) async def process_provisional_proposal( self, proposal: ProvisionalProposal, recv_persister: InMemoryReceiverPersister, + mode: TransitionMode, ): - payjoin_proposal = proposal.finalize_proposal( - ProcessPsbtCallback(self.receiver) - ).save(recv_persister) + if mode == "callback": + payjoin_proposal = proposal.finalize_proposal( + ProcessPsbtCallback(self.receiver) + ).save(recv_persister) + else: + signed_psbt = ProcessPsbtCallback(self.receiver).callback( + proposal.psbt_to_sign() + ) + payjoin_proposal = proposal.finalize_signed_proposal(signed_psbt).save( + recv_persister + ) return ReceiveSession.PAYJOIN_PROPOSAL(payjoin_proposal) - async def test_integration_v2_to_v2(self): + def setUp(self): + sender_address = json.loads(self.sender.call("getnewaddress", [])) + self.sender.call( + "generatetoaddress", [json.dumps(101), json.dumps(sender_address)] + ) + + async def _run_integration_v2_to_v2(self, mode: TransitionMode): try: receiver_address = json.loads(self.receiver.call("getnewaddress", [])) init_tracing() @@ -271,6 +366,7 @@ async def test_integration_v2_to_v2(self): cast(ReceiveSession, ReceiveSession.INITIALIZED(session)), recv_persister, ohttp_relay, + mode, ) self.assertIsNone(process_response) @@ -305,6 +401,7 @@ async def test_integration_v2_to_v2(self): cast(ReceiveSession, ReceiveSession.INITIALIZED(session)), recv_persister, ohttp_relay, + mode, ) self.assertIsNotNone(payjoin_proposal) self.assertEqual( @@ -385,6 +482,12 @@ async def test_integration_v2_to_v2(self): print("Caught:", e) raise + async def test_integration_v2_to_v2_callback(self): + await self._run_integration_v2_to_v2("callback") + + async def test_integration_v2_to_v2_nonblocking(self): + await self._run_integration_v2_to_v2("nonblocking") + def build_sweep_psbt(sender: RpcClient, pj_uri: PjUri) -> str: outputs = {} diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 47d9833f0..04efaa926 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -344,9 +344,6 @@ impl InitialReceiveTransition { } } -#[derive(Clone, Debug, uniffi::Object)] -pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); - /// Primitive representation of a transaction output for the FFI boundary. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] pub struct TxOut { @@ -394,7 +391,7 @@ impl TxIn { } /// Primitive representation of an outpoint for the FFI boundary. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] +#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, uniffi::Record)] pub struct OutPoint { /// Hex-encoded txid (big-endian). pub txid: String, @@ -458,6 +455,9 @@ impl From for Weight { fn from(value: payjoin::bitcoin::Weight) -> Self { Weight { weight_units: value.to_wu() } } } +#[derive(Clone, Debug, uniffi::Object)] +pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); + #[uniffi::export] impl ReceiverBuilder { /// Creates a new [`Initialized`] with the provided parameters. @@ -728,6 +728,23 @@ impl UncheckedOriginalPayload { ))))) } + pub fn extract_tx_to_check_broadcast_suitability(&self) -> Vec { + payjoin::bitcoin::consensus::encode::serialize( + &self.0.clone().extract_tx_to_check_broadcast_suitability(), + ) + } + + pub fn apply_broadcast_suitability( + &self, + min_fee_rate_sat_per_kwu: Option, + can_broadcast: bool, + ) -> Result { + let min_fee_rate = validate_fee_rate_sat_per_kwu_opt(min_fee_rate_sat_per_kwu)?; + Ok(UncheckedOriginalPayloadTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_broadcast_suitability(min_fee_rate, can_broadcast), + ))))) + } + /// Call this method if the only way to initiate a Payjoin with this receiver /// requires manual intervention, as in most consumer wallets. /// @@ -740,6 +757,65 @@ impl UncheckedOriginalPayload { } } +trait FfiTaggedReference { + fn get_result(&self) -> bool; + fn get_value(&self) -> V; +} + +fn to_tagged_references( + references: impl Iterator>, + checks: Vec>, +) -> Result< + impl Iterator>, + payjoin::error::ImplementationError, +> +where + V: Clone, + T: payjoin::receive::Tag, + R: FfiTaggedReference, + Vffi: From + PartialEq, +{ + payjoin::receive::check_references(references, &mut move |pj_ref: &V| { + let found_result = checks.iter().find_map(|ffi_ref| { + if Vffi::from(pj_ref.clone()) == ffi_ref.get_value() { + Some(ffi_ref.get_result()) + } else { + None + } + }); + match found_result { + Some(result) => Ok(result), + None => Err(payjoin::ImplementationError::from( + "Returned input owned checks is missing result for script {script}", + )), + } + }) +} + +#[derive(Debug, uniffi::Object)] +pub struct InputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputOwnedTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct InputOwnedTaggedReference { + value: Vec, + result: bool, +} + +impl FfiTaggedReference> for InputOwnedTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> Vec { self.value.clone() } +} + #[derive(Clone, uniffi::Object)] pub struct MaybeInputsOwned(payjoin::receive::v2::Receiver); @@ -783,6 +859,7 @@ impl MaybeInputsOwned { &self.0.clone().extract_tx_to_schedule_broadcast(), ) } + pub fn check_inputs_not_owned( &self, is_owned: Arc, @@ -793,6 +870,53 @@ impl MaybeInputsOwned { }), )))) } + + pub fn get_input_script_refs(&self) -> Result>, ReceiverError> { + self.0 + .clone() + .get_input_script_refs() + .map(|iter| { + iter.map(|input_script_ref| Arc::new(InputOwnedReference(input_script_ref))) + .collect::>() + }) + .map_err(ReceiverError::from) + } + + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: Vec>, + ) -> Result { + let references = self.0.clone().get_input_script_refs()?; + let checked = to_tagged_references(references, checked_input_scripts) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(MaybeInputsOwnedTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_owned_checks(checked), + ))))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct InputSeenReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputSeenReference { + pub fn get_value(&self) -> OutPoint { self.0.get_value().into() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputSeenTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct InputSeenTaggedReference { + value: OutPoint, + result: bool, +} + +impl FfiTaggedReference for InputSeenTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> OutPoint { self.value.clone() } } #[derive(Clone, uniffi::Object)] @@ -844,6 +968,50 @@ impl MaybeInputsSeen { }), )))) } + + pub fn get_input_outpoint_refs(&self) -> Vec> { + self.0 + .clone() + .get_input_outpoint_refs() + .map(|input_outpoint_ref| Arc::new(InputSeenReference(input_outpoint_ref))) + .collect::>() + } + + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: Vec>, + ) -> Result { + let references = self.0.clone().get_input_outpoint_refs(); + let checked = to_tagged_references(references, checked_input_outpoints) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(MaybeInputsSeenTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_seen_checks(checked), + ))))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct OutputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl OutputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + pub fn mark(&self, result: bool) -> Arc { + Arc::new(OutputOwnedTaggedReference { value: self.get_value(), result }) + } +} + +#[derive(Debug, Clone, uniffi::Object)] +pub struct OutputOwnedTaggedReference { + value: Vec, + result: bool, +} + +impl FfiTaggedReference> for OutputOwnedTaggedReference { + fn get_result(&self) -> bool { self.result } + fn get_value(&self) -> Vec { self.value.clone() } } /// The receiver has not yet identified which outputs belong to the receiver. @@ -893,6 +1061,26 @@ impl OutputsUnknown { }), )))) } + + pub fn get_output_script_refs(&self) -> Vec> { + self.0 + .clone() + .get_output_script_refs() + .map(|output_script_ref| Arc::new(OutputOwnedReference(output_script_ref))) + .collect::>() + } + + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: Vec>, + ) -> Result { + let references = self.0.clone().get_output_script_refs(); + let checked = to_tagged_references(references, checked_output_scripts) + .map_err(|e| ReceiverError::Implementation(Arc::new(ImplementationError::from(e))))?; + Ok(OutputsUnknownTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_output_owned_checks(checked), + ))))) + } } #[derive(uniffi::Object)] @@ -1177,6 +1365,14 @@ impl ProvisionalProposal { } pub fn psbt_to_sign(&self) -> String { self.0.clone().psbt_to_sign().to_string() } + + pub fn finalize_signed_proposal(&self, signed_psbt: String) -> ProvisionalProposalTransition { + ProvisionalProposalTransition(Arc::new(RwLock::new(Some( + self.0.clone().finalize_proposal(|_| { + Ok(Psbt::from_str(&signed_psbt).map_err(ImplementationError::new)?) + }), + )))) + } } #[derive(Clone, uniffi::Object)] @@ -1433,6 +1629,29 @@ impl Monitor { .map_err(|e| ImplementationError::new(e).into()) }))))) } + pub fn extract_fallback_txid(&self) -> String { + self.0.clone().extract_fallback_txid().to_string() + } + + pub fn extract_payjoin_proposal_txid(&self) -> String { + self.0.clone().extract_payjoin_proposal_txid().to_string() + } + + pub fn check_fallback_monitorable(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().check_fallback_monitorable())))) + } + + pub fn fallback_tx_exists(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().fallback_tx_exists())))) + } + + pub fn payjoin_tx_exists( + &self, + payjoin_tx: Vec, + ) -> Result { + let tx = try_deserialize_tx(payjoin_tx)?; + Ok(MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().payjoin_tx_exists(tx)))))) + } } /// Session persister that should save and load events as JSON strings. From 258938d3a47f1d5f6b45e972a57692a8a4475b47 Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Sat, 23 May 2026 10:12:53 -0500 Subject: [PATCH 6/8] Add missing comments for FFI bindings --- payjoin-ffi/src/receive/mod.rs | 137 ++++++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 3 deletions(-) diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 04efaa926..b2af0b1af 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -632,7 +632,11 @@ impl Initialized { .map_err(Into::into) } - /// The response can either be an UncheckedOriginalPayload or an ACCEPTED message indicating no UncheckedOriginalPayload is available yet. + /// Process the polling response from the directory. + /// + /// Returns an [`InitializedTransition`] that, once persisted, yields either + /// an [`UncheckedOriginalPayload`] if the sender's Original PSBT is available, + /// or [`Initialized`] if no proposal has arrived yet. pub fn process_response(&self, body: &[u8], ctx: &ClientResponse) -> InitializedTransition { InitializedTransition(Arc::new(RwLock::new(Some( self.0.clone().process_response(body, ctx.into()), @@ -713,6 +717,11 @@ pub trait CanBroadcast: Send + Sync { #[uniffi::export] impl UncheckedOriginalPayload { + /// Check that the sender's Original PSBT is suitable for broadcast, ensuring + /// it can be used as a fallback if the payjoin does not complete. + /// + /// Returns an [`UncheckedOriginalPayloadTransition`] that, once persisted, + /// yields a [`MaybeInputsOwned`] to continue validation. pub fn check_broadcast_suitability( &self, min_fee_rate_sat_per_kwu: Option, @@ -728,12 +737,21 @@ impl UncheckedOriginalPayload { ))))) } + /// Extract the transaction from the Original PSBT for external broadcast suitability checks. + /// + /// Returns the consensus-encoded raw transaction bytes. pub fn extract_tx_to_check_broadcast_suitability(&self) -> Vec { payjoin::bitcoin::consensus::encode::serialize( &self.0.clone().extract_tx_to_check_broadcast_suitability(), ) } + /// Apply the result of an external broadcast suitability check, ensuring + /// the Original PSBT can be used as a fallback if the payjoin does + /// not complete. + /// + /// Returns an [`UncheckedOriginalPayloadTransition`] that, once persisted, + /// yields a [`MaybeInputsOwned`] to continue validation. pub fn apply_broadcast_suitability( &self, min_fee_rate_sat_per_kwu: Option, @@ -853,13 +871,21 @@ pub trait IsScriptOwned: Send + Sync { #[uniffi::export] impl MaybeInputsOwned { - ///The Sender’s Original PSBT + /// Extract the transaction from the Original PSBT for scheduling broadcast as a + /// fallback in case the payjoin does not complete. + /// + /// Returns the consensus-encoded raw transaction bytes. pub fn extract_tx_to_schedule_broadcast(&self) -> Vec { payjoin::bitcoin::consensus::encode::serialize( &self.0.clone().extract_tx_to_schedule_broadcast(), ) } + /// Check that none of the Original PSBT's inputs belong to the receiver, + /// preventing an attacker from spending the receiver's own inputs. + /// + /// Returns a [`MaybeInputsOwnedTransition`] that, once persisted, + /// yields a [`MaybeInputsSeen`] to continue validation. pub fn check_inputs_not_owned( &self, is_owned: Arc, @@ -871,6 +897,10 @@ impl MaybeInputsOwned { )))) } + /// Get references to the input scripts for external ownership checks. + /// + /// Each reference can be marked with the result via [`InputOwnedReference::mark`] + /// and passed to [`MaybeInputsOwned::apply_input_owned_checks`]. pub fn get_input_script_refs(&self) -> Result>, ReceiverError> { self.0 .clone() @@ -882,6 +912,12 @@ impl MaybeInputsOwned { .map_err(ReceiverError::from) } + /// Apply the results of external input ownership checks, ensuring none of the + /// inputs are owned by the receiver. This prevents an attacker from spending + /// the receiver's own inputs. + /// + /// Returns a [`MaybeInputsOwnedTransition`] that, once persisted, + /// yields a [`MaybeInputsSeen`] to continue validation. pub fn apply_input_owned_checks( &self, checked_input_scripts: Vec>, @@ -956,6 +992,12 @@ pub trait IsOutputKnown: Send + Sync { #[uniffi::export] impl MaybeInputsSeen { + /// Check that none of the inputs have been seen before, preventing input + /// probing and replay attacks (where inputs have been used in a previous + /// payjoin attempt). + /// + /// Returns a [`MaybeInputsSeenTransition`] that, once persisted, + /// yields an [`OutputsUnknown`] to continue validation. pub fn check_no_inputs_seen_before( &self, is_known: Arc, @@ -969,6 +1011,10 @@ impl MaybeInputsSeen { )))) } + /// Get references to the input outpoints for external outpoint seen checks. + /// + /// Each reference can be marked with the result via [`InputSeenReference::mark`] + /// and passed to [`MaybeInputsSeen::apply_input_seen_checks`]. pub fn get_input_outpoint_refs(&self) -> Vec> { self.0 .clone() @@ -977,6 +1023,12 @@ impl MaybeInputsSeen { .collect::>() } + /// Apply the results of external outpoint seen checks, ensuring none of + /// the inputs have been seen before. This prevents input probing and replay + /// attacks (where inputs have been used in a previous payjoin attempt). + /// + /// Returns a [`MaybeInputsSeenTransition`] that, once persisted, + /// yields an [`OutputsUnknown`] to continue validation. pub fn apply_input_seen_checks( &self, checked_input_outpoints: Vec>, @@ -1048,7 +1100,11 @@ impl_save_for_transition!(OutputsUnknownTransition, WantsOutputs); #[uniffi::export] impl OutputsUnknown { - /// Find which outputs belong to the receiver + /// Identify which outputs in the original transaction belong to the receiver + /// and ensure at least one output pays the receiver. + /// + /// Returns an [`OutputsUnknownTransition`] that, once persisted, + /// yields a [`WantsOutputs`] to continue the proposal. pub fn identify_receiver_outputs( &self, is_receiver_output: Arc, @@ -1062,6 +1118,10 @@ impl OutputsUnknown { )))) } + /// Get references to the output scripts for external ownership checks. + /// + /// Each reference can be marked with the result via [`OutputOwnedReference::mark`] + /// and passed to [`OutputsUnknown::apply_output_owned_checks`]. pub fn get_output_script_refs(&self) -> Vec> { self.0 .clone() @@ -1070,6 +1130,12 @@ impl OutputsUnknown { .collect::>() } + /// Apply the results of external output ownership checks, identifying which + /// outputs in the original transaction belong to the receiver and ensuring + /// at least one output pays the receiver. + /// + /// Returns an [`OutputsUnknownTransition`] that, once persisted, + /// yields a [`WantsOutputs`] to continue the proposal. pub fn apply_output_owned_checks( &self, checked_output_scripts: Vec>, @@ -1111,8 +1177,14 @@ impl_save_for_transition!(WantsOutputsTransition, WantsInputs); #[uniffi::export] impl WantsOutputs { + /// Returns whether output substitution is enabled for this session. pub fn output_substitution(&self) -> OutputSubstitution { self.0.output_substitution() } + /// Replace all receiver outputs with the provided `replacement_outputs`, + /// and set up the `drain_script` as the receiver-owned output whose value + /// may be adjusted based on modifications in subsequent states. + /// + /// Returns an updated [`WantsOutputs`] with the replaced outputs. pub fn replace_receiver_outputs( &self, replacement_outputs: Vec, @@ -1130,6 +1202,9 @@ impl WantsOutputs { .map_err(Into::into) } + /// Substitute the receiver output script with the provided script. + /// + /// Returns an updated [`WantsOutputs`] with the substituted output. pub fn substitute_receiver_script( &self, output_script_pubkey: Vec, @@ -1143,6 +1218,10 @@ impl WantsOutputs { .map_err(Into::into) } + /// Commit the output modifications and proceed to input contribution. + /// + /// Returns a [`WantsOutputsTransition`] that, once persisted, + /// yields a [`WantsInputs`]. pub fn commit_outputs(&self) -> WantsOutputsTransition { WantsOutputsTransition(Arc::new(RwLock::new(Some(self.0.clone().commit_outputs())))) } @@ -1199,6 +1278,9 @@ impl WantsInputs { } } + /// Add the provided inputs to the payjoin proposal. + /// + /// Returns an updated [`WantsInputs`] with the contributed inputs. pub fn contribute_inputs( &self, replacement_inputs: Vec>, @@ -1212,6 +1294,10 @@ impl WantsInputs { .map_err(Into::into) } + /// Commit the input contributions and proceed to fee negotiation. + /// + /// Returns a [`WantsInputsTransition`] that, once persisted, + /// yields a [`WantsFeeRange`]. pub fn commit_inputs(&self) -> WantsInputsTransition { WantsInputsTransition(Arc::new(RwLock::new(Some(self.0.clone().commit_inputs())))) } @@ -1350,6 +1436,10 @@ pub trait ProcessPsbt: Send + Sync { #[uniffi::export] impl ProvisionalProposal { + /// Finalize the proposal by signing the PSBT via the `process_psbt` callback. + /// + /// Returns a [`ProvisionalProposalTransition`] that, once persisted, + /// yields the final [`PayjoinProposal`]. pub fn finalize_proposal( &self, process_psbt: Arc, @@ -1364,8 +1454,13 @@ impl ProvisionalProposal { )))) } + /// Extract the PSBT that needs to be signed by the receiver's wallet. pub fn psbt_to_sign(&self) -> String { self.0.clone().psbt_to_sign().to_string() } + /// Finalize the proposal with a signed PSBT. + /// + /// Returns a [`ProvisionalProposalTransition`] that, once persisted, + /// yields the final [`PayjoinProposal`]. pub fn finalize_signed_proposal(&self, signed_psbt: String) -> ProvisionalProposalTransition { ProvisionalProposalTransition(Arc::new(RwLock::new(Some( self.0.clone().finalize_proposal(|_| { @@ -1414,6 +1509,8 @@ impl_save_for_transition!(PayjoinProposalTransition, Monitor); #[uniffi::export] impl PayjoinProposal { + /// Returns the outpoints of receiver UTXOs that should be locked to + /// prevent double-spending while the payjoin is in progress. pub fn utxos_to_be_locked(&self) -> Vec { let mut outpoints: Vec = Vec::new(); for o in String { , @@ -1529,6 +1627,8 @@ impl HasReplyableErrorTransition { #[uniffi::export] impl HasReplyableError { + /// Construct an OHTTP encapsulated POST request to post the receiver + /// error response to the directory so it can be retrieved by the sender. pub fn create_error_request( &self, ohttp_relay: String, @@ -1538,6 +1638,11 @@ impl HasReplyableError { }) } + /// Process the response from the directory after posting the receiver + /// error response. + /// + /// Returns a [`HasReplyableErrorTransition`] that, once persisted, + /// completes the error reporting. pub fn process_error_response( &self, body: &[u8], @@ -1618,6 +1723,11 @@ fn try_deserialize_tx( #[uniffi::export] impl Monitor { + /// Check the network for the payjoin or fallback transaction via the + /// `transaction_exists` callback. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session if a transaction is found. pub fn check_payment( &self, transaction_exists: Arc, @@ -1629,22 +1739,43 @@ impl Monitor { .map_err(|e| ImplementationError::new(e).into()) }))))) } + + /// Returns the txid of the fallback transaction. pub fn extract_fallback_txid(&self) -> String { self.0.clone().extract_fallback_txid().to_string() } + /// Returns the txid of the payjoin proposal transaction. pub fn extract_payjoin_proposal_txid(&self) -> String { self.0.clone().extract_payjoin_proposal_txid().to_string() } + /// Check whether the fallback transaction can be monitored. If the + /// fallback transaction includes non-SegWit inputs, the fallback + /// transaction ID can change when the sender signs again, making + /// monitoring impossible and concluding the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, yields a + /// [`Monitor`] to continue monitoring or completes the session if + /// monitoring is not possible. pub fn check_fallback_monitorable(&self) -> MonitorTransition { MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().check_fallback_monitorable())))) } + /// Signal that the fallback transaction exists on the network, + /// completing the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session. pub fn fallback_tx_exists(&self) -> MonitorTransition { MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().fallback_tx_exists())))) } + /// Signal that the payjoin transaction exists on the network, + /// completing the session. + /// + /// Returns a [`MonitorTransition`] that, once persisted, completes + /// the session. pub fn payjoin_tx_exists( &self, payjoin_tx: Vec, From 9eaebe1d7931686b44866a6857e0db5bda6de49b Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Wed, 20 May 2026 13:49:55 -0500 Subject: [PATCH 7/8] Update payjoin-cli to non-blocking receive interface Migrate both v1 and v2 receiver flows in payjoin-cli from the callback-based validation API to the two-phase extract/apply non-blocking API. --- payjoin-cli/src/app/v1.rs | 26 +++++---- payjoin-cli/src/app/v2/mod.rs | 101 +++++++++++++++++----------------- 2 files changed, 65 insertions(+), 62 deletions(-) diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index f557facd3..41681a6d1 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -13,7 +13,7 @@ use hyper_util::rt::TokioIo; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::receive::v1::{PayjoinProposal, UncheckedOriginalPayload}; -use payjoin::receive::Error; +use payjoin::receive::{check_references, Error}; use payjoin::send::v1::SenderBuilder; use payjoin::{ImplementationError, IntoUrl, Uri, UriExt}; use tokio::net::TcpListener; @@ -347,33 +347,37 @@ impl App { let wallet = self.wallet(); // Receive Check 1: Can Broadcast - let proposal = proposal.check_broadcast_suitability(None, |tx| { - wallet - .can_broadcast(tx) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - })?; + let is_broadcast_suitable = wallet + .can_broadcast(&proposal.extract_tx_to_check_broadcast_suitability()) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = proposal.apply_broadcast_suitability(None, is_broadcast_suitable)?; tracing::trace!("check1"); // in a payment processor where the sender could go offline, this is where you schedule to broadcast the original_tx let _to_broadcast_in_failure_case = proposal.extract_tx_to_schedule_broadcast(); // Receive Check 2: receiver can't sign for proposal inputs - let proposal = proposal.check_inputs_not_owned(&mut |input| { + let refs = proposal.get_input_script_refs()?; + let checked_refs = check_references(refs, &mut |input| { wallet.is_mine(input).map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) })?; + let proposal = proposal.apply_input_owned_checks(checked_refs)?; tracing::trace!("check2"); // Receive Check 3: have we seen this input before? More of a check for non-interactive i.e. payment processor receivers. - let payjoin = proposal.check_no_inputs_seen_before(&mut |input| { - Ok(self.db.insert_input_seen_before(*input)?) - })?; + let refs = proposal.get_input_outpoint_refs(); + let checked_refs = + check_references(refs, &mut |input| Ok(self.db.insert_input_seen_before(*input)?))?; + let payjoin = proposal.apply_input_seen_checks(checked_refs)?; tracing::trace!("check3"); - let payjoin = payjoin.identify_receiver_outputs(&mut |output_script| { + let refs = payjoin.get_output_script_refs(); + let checked_refs = check_references(refs, &mut |output_script| { wallet .is_mine(output_script) .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) })?; + let payjoin = payjoin.apply_output_owned_checks(checked_refs)?; let payjoin = payjoin .substitute_receiver_script( diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 67dca2910..739418618 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, Context, Result}; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::persist::{OptionalTransitionOutcome, SessionPersister}; +use payjoin::receive::check_references; use payjoin::receive::v2::{ replay_event_log as replay_receiver_event_log, HasReplyableError, Initialized, MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, @@ -702,13 +703,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .check_broadcast_suitability(None, |tx| { - wallet - .can_broadcast(tx) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let is_broadcast_suitable = wallet + .can_broadcast(&proposal.extract_tx_to_check_broadcast_suitability()) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = + proposal.apply_broadcast_suitability(None, is_broadcast_suitable).save(persister)?; println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:"); println!("{}", serialize_hex(&proposal.extract_tx_to_schedule_broadcast())); @@ -721,13 +720,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .check_inputs_not_owned(&mut |input| { - wallet - .is_mine(input) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let refs = proposal.get_input_script_refs()?; + let checked_refs = check_references(refs, &mut |input| { + wallet.is_mine(input).map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + })?; + let proposal = proposal.apply_input_owned_checks(checked_refs).save(persister)?; self.check_no_inputs_seen_before(proposal, persister).await } @@ -736,11 +733,10 @@ impl App { proposal: Receiver, persister: &ReceiverPersister, ) -> Result<()> { - let proposal = proposal - .check_no_inputs_seen_before(&mut |input| { - Ok(self.db.insert_input_seen_before(*input)?) - }) - .save(persister)?; + let refs = proposal.get_input_outpoint_refs(); + let checked_refs = + check_references(refs, &mut |input| Ok(self.db.insert_input_seen_before(*input)?))?; + let proposal = proposal.apply_input_seen_checks(checked_refs).save(persister)?; self.identify_receiver_outputs(proposal, persister).await } @@ -750,13 +746,13 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .identify_receiver_outputs(&mut |output_script| { - wallet - .is_mine(output_script) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let refs = proposal.get_output_script_refs(); + let checked_refs = check_references(refs, &mut |output_script| { + wallet + .is_mine(output_script) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + })?; + let proposal = proposal.apply_output_owned_checks(checked_refs).save(persister)?; self.commit_outputs(proposal, persister).await } @@ -804,13 +800,11 @@ impl App { persister: &ReceiverPersister, ) -> Result<()> { let wallet = self.wallet(); - let proposal = proposal - .finalize_proposal(|psbt| { - wallet - .process_psbt(psbt) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister)?; + let psbt = proposal.psbt_to_sign(); + let signed_psbt = wallet + .process_psbt(&psbt) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))?; + let proposal = proposal.finalize_signed_proposal(&signed_psbt).save(persister)?; self.send_payjoin_proposal(proposal, persister).await } @@ -847,27 +841,32 @@ impl App { tracing::debug!("Polling for payment confirmation"); + let fallback_txid = proposal.extract_fallback_txid(); + let payjoin_txid = proposal.extract_payjoin_proposal_txid(); + let get_raw_tx = |txid| { + self.wallet() + .get_raw_transaction(&txid) + .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) + }; + match proposal.check_fallback_monitorable().save(persister)? { + OptionalTransitionOutcome::Progress(_) => { + println!("Unable to monitor for fallback tx containing non-segwit inputs, completing session"); + return Ok(()); + } + OptionalTransitionOutcome::Stasis(_) => {} + } let result = tokio::time::timeout(timeout_duration, async { loop { interval.tick().await; - let check_result = proposal - .check_payment(|txid| { - self.wallet() - .get_raw_transaction(&txid) - .map_err(|e| ImplementationError::from(e.into_boxed_dyn_error())) - }) - .save(persister); - - match check_result { - Ok(_) => { - println!("Payjoin transaction detected in the mempool!"); - return Ok(()); - } - Err(_) => { - // keep polling - - continue; - } + if let Some(tx) = get_raw_tx(payjoin_txid)? { + proposal.payjoin_tx_exists(tx).save(persister)?; + println!("Payjoin transaction detected in the mempool!"); + return Ok(()); + }; + if get_raw_tx(fallback_txid)?.is_some() { + proposal.fallback_tx_exists().save(persister)?; + println!("Fallback transaction detected in the mempool!"); + return Ok(()); } } }) From 38b475e6d59ee66efa1cb129dab3d5edd3521d8a Mon Sep 17 00:00:00 2001 From: xstoicunicornx Date: Sat, 23 May 2026 12:00:42 -0500 Subject: [PATCH 8/8] Remove unused PsbtContext::finalize_proposal --- payjoin/src/core/receive/mod.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index 6144cece3..c7a9e9b90 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -466,19 +466,6 @@ impl PsbtContext { psbt } - /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before - /// they sign the transaction and broadcast it to the network. - /// - /// Finalization consists of signing and finalizing the PSBT using the passed `wallet_process_psbt` signing function. - fn finalize_proposal( - self, - wallet_process_psbt: impl Fn(&Psbt) -> Result, - ) -> Result { - let psbt = self.psbt_to_sign(); - let signed_psbt = wallet_process_psbt(&psbt)?; - self.finalize_signed_proposal(signed_psbt) - } - /// Finalizes the signed payjoin proposal PSBT which the sender will find acceptable before /// they sign the transaction and broadcast it to the network. ///