Skip to content
94 changes: 73 additions & 21 deletions lightning/src/blinded_path/payment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,12 @@ impl BlindedPaymentPath {
)
}

fn new_inner<ES: EntropySource, T: secp256k1::Signing + secp256k1::Verification>(
intermediate_nodes: &[PaymentForwardNode], payee_node_id: PublicKey,
fn new_inner<
F: ForwardTlvsInfo,
ES: EntropySource,
T: secp256k1::Signing + secp256k1::Verification,
>(
intermediate_nodes: &[ForwardNode<F>], payee_node_id: PublicKey,
local_node_receive_key: ReceiveAuthKey, dummy_tlvs: &[DummyTlvs], payee_tlvs: ReceiveTlvs,
htlc_maximum_msat: u64, min_final_cltv_expiry_delta: u16, entropy_source: ES,
secp_ctx: &Secp256k1<T>,
Expand Down Expand Up @@ -323,18 +327,36 @@ impl BlindedPaymentPath {
}
}

/// An intermediate node, its outbound channel, and relay parameters.
/// Common interface for forward TLV types used in blinded payment paths.
///
/// Both [`ForwardTlvs`] (channel-based forwarding) and [`TrampolineForwardTlvs`] (trampoline
/// node-based forwarding) implement this trait, allowing blinded path construction to be generic
/// over the forwarding mechanism.
pub trait ForwardTlvsInfo: Writeable + Clone {
/// The payment relay parameters for this hop.
fn payment_relay(&self) -> &PaymentRelay;
/// The payment constraints for this hop.
fn payment_constraints(&self) -> &PaymentConstraints;
/// The features for this hop.
fn features(&self) -> &BlindedHopFeatures;
}

/// An intermediate node, its forwarding parameters, and its [`ForwardTlvsInfo`] for use in a
/// [`BlindedPaymentPath`].
#[derive(Clone, Debug)]
pub struct PaymentForwardNode {
pub struct ForwardNode<F: ForwardTlvsInfo> {
/// The TLVs for this node's [`BlindedHop`], where the fee parameters contained within are also
/// used for [`BlindedPayInfo`] construction.
pub tlvs: ForwardTlvs,
pub tlvs: F,
/// This node's pubkey.
pub node_id: PublicKey,
/// The maximum value, in msat, that may be accepted by this node.
pub htlc_maximum_msat: u64,
}

/// An intermediate node for a regular (non-trampoline) [`BlindedPaymentPath`].
pub type PaymentForwardNode = ForwardNode<ForwardTlvs>;

/// Data to construct a [`BlindedHop`] for forwarding a payment.
#[derive(Clone, Debug)]
pub struct ForwardTlvs {
Expand All @@ -354,6 +376,18 @@ pub struct ForwardTlvs {
pub next_blinding_override: Option<PublicKey>,
}

impl ForwardTlvsInfo for ForwardTlvs {
fn payment_relay(&self) -> &PaymentRelay {
&self.payment_relay
}
fn payment_constraints(&self) -> &PaymentConstraints {
&self.payment_constraints
}
fn features(&self) -> &BlindedHopFeatures {
&self.features
}
}

/// Data to construct a [`BlindedHop`] for forwarding a Trampoline payment.
#[derive(Clone, Debug)]
pub struct TrampolineForwardTlvs {
Expand All @@ -373,6 +407,18 @@ pub struct TrampolineForwardTlvs {
pub next_blinding_override: Option<PublicKey>,
}

impl ForwardTlvsInfo for TrampolineForwardTlvs {
fn payment_relay(&self) -> &PaymentRelay {
&self.payment_relay
}
fn payment_constraints(&self) -> &PaymentConstraints {
&self.payment_constraints
}
fn features(&self) -> &BlindedHopFeatures {
&self.features
}
}

/// TLVs carried by a dummy hop within a blinded payment path.
///
/// Dummy hops do not correspond to real forwarding decisions, but are processed
Expand Down Expand Up @@ -440,8 +486,8 @@ pub(crate) enum BlindedTrampolineTlvs {

// Used to include forward and receive TLVs in the same iterator for encoding.
#[derive(Clone)]
enum BlindedPaymentTlvsRef<'a> {
Forward(&'a ForwardTlvs),
enum BlindedPaymentTlvsRef<'a, F: ForwardTlvsInfo = ForwardTlvs> {
Forward(&'a F),
Dummy(&'a DummyTlvs),
Receive(&'a ReceiveTlvs),
}
Expand Down Expand Up @@ -619,7 +665,7 @@ impl Writeable for ReceiveTlvs {
}
}

impl<'a> Writeable for BlindedPaymentTlvsRef<'a> {
impl<'a, F: ForwardTlvsInfo> Writeable for BlindedPaymentTlvsRef<'a, F> {
fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
match self {
Self::Forward(tlvs) => tlvs.write(w)?,
Expand Down Expand Up @@ -723,8 +769,8 @@ impl Readable for BlindedTrampolineTlvs {
pub(crate) const PAYMENT_PADDING_ROUND_OFF: usize = 30;

/// Construct blinded payment hops for the given `intermediate_nodes` and payee info.
pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
secp_ctx: &Secp256k1<T>, intermediate_nodes: &[PaymentForwardNode], payee_node_id: PublicKey,
pub(super) fn blinded_hops<F: ForwardTlvsInfo, T: secp256k1::Signing + secp256k1::Verification>(
secp_ctx: &Secp256k1<T>, intermediate_nodes: &[ForwardNode<F>], payee_node_id: PublicKey,
dummy_tlvs: &[DummyTlvs], payee_tlvs: ReceiveTlvs, session_priv: &SecretKey,
local_node_receive_key: ReceiveAuthKey,
) -> Vec<BlindedHop> {
Expand Down Expand Up @@ -823,15 +869,15 @@ where
Ok((curr_base_fee, curr_prop_mil))
}

pub(super) fn compute_payinfo(
intermediate_nodes: &[PaymentForwardNode], dummy_tlvs: &[DummyTlvs], payee_tlvs: &ReceiveTlvs,
pub(super) fn compute_payinfo<F: ForwardTlvsInfo>(
intermediate_nodes: &[ForwardNode<F>], dummy_tlvs: &[DummyTlvs], payee_tlvs: &ReceiveTlvs,
payee_htlc_maximum_msat: u64, min_final_cltv_expiry_delta: u16,
) -> Result<BlindedPayInfo, ()> {
let routing_fees = intermediate_nodes
.iter()
.map(|node| RoutingFees {
base_msat: node.tlvs.payment_relay.fee_base_msat,
proportional_millionths: node.tlvs.payment_relay.fee_proportional_millionths,
base_msat: node.tlvs.payment_relay().fee_base_msat,
proportional_millionths: node.tlvs.payment_relay().fee_proportional_millionths,
})
.chain(dummy_tlvs.iter().map(|tlvs| RoutingFees {
base_msat: tlvs.payment_relay.fee_base_msat,
Expand All @@ -847,24 +893,24 @@ pub(super) fn compute_payinfo(
for node in intermediate_nodes.iter() {
// In the future, we'll want to take the intersection of all supported features for the
// `BlindedPayInfo`, but there are no features in that context right now.
if node.tlvs.features.requires_unknown_bits_from(&BlindedHopFeatures::empty()) {
if node.tlvs.features().requires_unknown_bits_from(&BlindedHopFeatures::empty()) {
return Err(());
}

cltv_expiry_delta =
cltv_expiry_delta.checked_add(node.tlvs.payment_relay.cltv_expiry_delta).ok_or(())?;
cltv_expiry_delta.checked_add(node.tlvs.payment_relay().cltv_expiry_delta).ok_or(())?;

// The min htlc for an intermediate node is that node's min minus the fees charged by all of the
// following hops for forwarding that min, since that fee amount will automatically be included
// in the amount that this node receives and contribute towards reaching its min.
htlc_minimum_msat = amt_to_forward_msat(
core::cmp::max(node.tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat),
&node.tlvs.payment_relay,
core::cmp::max(node.tlvs.payment_constraints().htlc_minimum_msat, htlc_minimum_msat),
node.tlvs.payment_relay(),
)
.unwrap_or(1); // If underflow occurs, we definitely reached this node's min
htlc_maximum_msat = amt_to_forward_msat(
core::cmp::min(node.htlc_maximum_msat, htlc_maximum_msat),
&node.tlvs.payment_relay,
node.tlvs.payment_relay(),
)
.ok_or(())?; // If underflow occurs, we cannot send to this hop without exceeding their max
}
Expand Down Expand Up @@ -1038,8 +1084,14 @@ mod tests {
payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 },
payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}),
};
let blinded_payinfo =
super::compute_payinfo(&[], &[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap();
let blinded_payinfo = super::compute_payinfo::<ForwardTlvs>(
&[],
&[],
&recv_tlvs,
4242,
TEST_FINAL_CLTV as u16,
)
.unwrap();
assert_eq!(blinded_payinfo.fee_base_msat, 0);
assert_eq!(blinded_payinfo.fee_proportional_millionths, 0);
assert_eq!(blinded_payinfo.cltv_expiry_delta, TEST_FINAL_CLTV as u16);
Expand Down
Loading
Loading