diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index 5b28abcf5..bac207c53 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -484,7 +484,9 @@ pub(crate) mod tests { witness, Amount, PubkeyHash, ScriptBuf, ScriptHash, Sequence, Txid, WScriptHash, XOnlyPublicKey, }; - use payjoin_test_utils::{DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, QUERY_PARAMS}; + use payjoin_test_utils::{ + DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS, + }; use super::*; use crate::psbt::InternalPsbtInputError::InvalidScriptPubKey; @@ -496,6 +498,24 @@ pub(crate) mod tests { OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } + pub(crate) fn original_missing_prevtxout_from_test_vector() -> OriginalPayload { + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); + let mut psbt: Psbt = PARSED_ORIGINAL_PSBT.clone(); + for psbtin in psbt.inputs_mut() { + psbtin.non_witness_utxo = None; + psbtin.witness_utxo = None; + } + OriginalPayload { psbt: psbt.clone(), params } + } + + pub(crate) fn psbt_context_from_test_vector() -> PsbtContext { + PsbtContext { + payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), + original_psbt: PARSED_ORIGINAL_PSBT.clone(), + } + } + #[test] fn input_pair_with_expected_weight() { let p2wsh_txout = TxOut { @@ -830,6 +850,141 @@ pub(crate) mod tests { assert_eq!(err, PsbtInputError::from(InternalPsbtInputError::ProvidedUnnecessaryWeight)); } + #[test] + fn test_check_broadcast_suitability() { + let original = original_from_test_vector(); + + // Outcome 1: min_fee_rate too high → PsbtBelowFeeRate error + let err = original + .clone() + .check_broadcast_suitability(Some(FeeRate::MAX), |_| Ok(true)) + .expect_err("Should fail when fee rate is below minimum"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PsbtBelowFeeRate(original_fee_rate, min_fee_rate), + ))) => { + assert_eq!(original_fee_rate, original.psbt_fee_rate().unwrap()); + assert_eq!(min_fee_rate, FeeRate::MAX); + } + _ => panic!("Expected PsbtBelowFeeRate error, got: {err:?}"), + } + + // Outcome 2: can_broadcast returns false → OriginalPsbtNotBroadcastable error + let err = original + .clone() + .check_broadcast_suitability(None, |_| Ok(false)) + .expect_err("Should fail when can_broadcast returns false"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::OriginalPsbtNotBroadcastable, + ))) => {} + _ => panic!("Expected OriginalPsbtNotBroadcastable error, got: {err:?}"), + } + + // Outcome 3: can_broadcast returns an implementation error → Error::Implementation + let err = original + .clone() + .check_broadcast_suitability(None, |_| { + Err(ImplementationError::from("broadcast check failed")) + }) + .expect_err("Should fail when can_broadcast returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "broadcast check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: success + original + .check_broadcast_suitability(None, |_| Ok(true)) + .expect("Should succeed when fee rate is acceptable and can_broadcast returns true"); + } + + #[test] + fn test_check_inputs_not_owned() { + let original = original_from_test_vector(); + let original_missing_prevtxout = original_missing_prevtxout_from_test_vector(); + + // Outcome 1: input_scripts returns a PrevTxOut error → Protocol error + let err = original_missing_prevtxout + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect_err("Should fail when previous txout is missing"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PrevTxOut(_), + ))) => {} + _ => panic!("Expected PrevTxOut error, got: {err:?}"), + } + + // Outcome 2: is_owned returns true → InputOwned error + let err = original + .clone() + .check_inputs_not_owned(&mut |_| Ok(true)) + .expect_err("Should fail when input is owned"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputOwned(_), + ))) => {} + _ => panic!("Expected InputOwned error, got: {err:?}"), + } + + // Outcome 3: is_owned returns an implementation error → Error::Implementation + let err = original + .clone() + .check_inputs_not_owned(&mut |_| { + Err(ImplementationError::from("ownership check failed")) + }) + .expect_err("Should fail when is_owned returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "ownership check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: is_owned returns false → success + original + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect("Should succeed when no inputs are owned"); + } + + #[test] + fn test_check_no_inputs_seen_before() { + let original = original_from_test_vector(); + + // Outcome 1: is_known returns true → InputSeen error + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| Ok(true)) + .expect_err("Should fail when input has been seen before"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputSeen(_), + ))) => {} + _ => panic!("Expected InputSeen error, got: {err:?}"), + } + + // Outcome 2: is_known returns an implementation error → Error::Implementation + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| { + Err(ImplementationError::from("input seen check failed")) + }) + .expect_err("Should fail when is_known returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "input seen check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 3: is_known returns false → success + original + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .expect("Should succeed when no inputs have been seen before"); + } + #[test] fn test_identify_receiver_outputs() { let original = original_from_test_vector(); @@ -864,4 +1019,32 @@ pub(crate) mod tests { assert_eq!(wants_outputs.owned_vouts, vec![0, 1]); assert_eq!(wants_outputs.params.additional_fee_contribution, None); } + + #[test] + fn test_finalize_proposal() { + let psbt_context = psbt_context_from_test_vector(); + + // Outcome 1: wallet_process_psbt returns an implementation error → ImplementationError + let err = psbt_context + .clone() + .finalize_proposal(|_| Err(ImplementationError::from("wallet signing failed"))) + .expect_err("Should fail when wallet_process_psbt returns an error"); + assert_eq!(err.to_string(), "wallet signing failed"); + + // Outcome 2: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError + let psbt_context = psbt_context_from_test_vector(); + let err = psbt_context + .clone() + .finalize_proposal(|_| { + // return a totally different psbt to trigger ntxid mismatch + Ok(PARSED_ORIGINAL_PSBT.clone()) + }) + .expect_err("Should fail when ntxid mismatches"); + assert!(err.to_string().contains("Ntxid mismatch")); + + // Outcome 3: wallet_process_psbt succeeds → Ok(Psbt) + let _psbt = psbt_context + .finalize_proposal(|_| Ok(PARSED_PAYJOIN_PROPOSAL.clone())) + .expect("Should succeed when wallet_process_psbt returns a valid signed psbt"); + } }