Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Add initial proptest to check forwarding amount calculation #56

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Pin crates for MSRV
if: matrix.msrv
run: |
# No need to pin currently
cargo update -p proptest --precise "1.2.0" --verbose # proptest 1.3.0 requires rustc 1.64.0
- name: Cargo check
run: cargo check --release
- name: Check documentation
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ core2 = { version = "0.3.0", optional = true, default-features = false }
chrono = { version = "0.4", default-features = false, features = ["serde", "alloc"] }
serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] }
serde_json = "1.0"

[dev-dependencies]
proptest = "1.0.0"
6 changes: 6 additions & 0 deletions proptest-regressions/lsps2/channel_manager.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
115 changes: 73 additions & 42 deletions src/lsps2/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,67 +779,98 @@ where
fn calculate_amount_to_forward_per_htlc(
htlcs: &[InterceptedHTLC], total_amt_to_forward_msat: u64,
) -> Vec<(InterceptId, u64)> {
// TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat
let total_received_msat: u64 =
htlcs.iter().map(|htlc| htlc.expected_outbound_amount_msat).sum();

let mut fee_remaining_msat = total_received_msat - total_amt_to_forward_msat;
let total_fee_msat = fee_remaining_msat;
match total_received_msat.checked_sub(total_amt_to_forward_msat) {
Some(total_fee_msat) => {
let mut fee_remaining_msat = total_fee_msat;

let mut per_htlc_forwards = vec![];
let mut per_htlc_forwards = vec![];

for (index, htlc) in htlcs.iter().enumerate() {
let proportional_fee_amt_msat =
total_fee_msat * htlc.expected_outbound_amount_msat / total_received_msat;
for (index, htlc) in htlcs.iter().enumerate() {
let proportional_fee_amt_msat =
total_fee_msat * (htlc.expected_outbound_amount_msat / total_received_msat);

let mut actual_fee_amt_msat = core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat);
fee_remaining_msat -= actual_fee_amt_msat;
let mut actual_fee_amt_msat =
core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat);
fee_remaining_msat -= actual_fee_amt_msat;

if index == htlcs.len() - 1 {
actual_fee_amt_msat += fee_remaining_msat;
}
if index == htlcs.len() - 1 {
actual_fee_amt_msat += fee_remaining_msat;
}

let amount_to_forward_msat = htlc.expected_outbound_amount_msat - actual_fee_amt_msat;
let amount_to_forward_msat =
htlc.expected_outbound_amount_msat.saturating_sub(actual_fee_amt_msat);

per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat))
}
per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat))
}

per_htlc_forwards
per_htlc_forwards
}
None => Vec::new(),
}
}

#[cfg(test)]
mod tests {

use super::*;
use proptest::prelude::*;

#[test]
fn test_calculate_amount_to_forward() {
// TODO: Use proptest to generate random allocations
let htlcs = vec![
InterceptedHTLC {
intercept_id: InterceptId([0; 32]),
expected_outbound_amount_msat: 1000,
},
InterceptedHTLC {
intercept_id: InterceptId([1; 32]),
expected_outbound_amount_msat: 2000,
},
InterceptedHTLC {
intercept_id: InterceptId([2; 32]),
expected_outbound_amount_msat: 3000,
},
];

let total_amt_to_forward_msat = 5000;

let result = calculate_amount_to_forward_per_htlc(&htlcs, total_amt_to_forward_msat);
const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000;

assert_eq!(result[0].0, htlcs[0].intercept_id);
assert_eq!(result[0].1, 834);
fn arb_forward_amounts() -> impl Strategy<Value = (u64, u64, u64, u64)> {
(1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT)
.prop_map(|(a, b, c, d)| {
(a, b, c, core::cmp::min(d, a.saturating_add(b).saturating_add(c)))
})
}

assert_eq!(result[1].0, htlcs[1].intercept_id);
assert_eq!(result[1].1, 1667);
proptest! {
#[test]
fn test_calculate_amount_to_forward((o_0, o_1, o_2, total_amt_to_forward_msat) in arb_forward_amounts()) {
let htlcs = vec![
InterceptedHTLC {
intercept_id: InterceptId([0; 32]),
expected_outbound_amount_msat: o_0
},
InterceptedHTLC {
intercept_id: InterceptId([1; 32]),
expected_outbound_amount_msat: o_1
},
InterceptedHTLC {
intercept_id: InterceptId([2; 32]),
expected_outbound_amount_msat: o_2
},
];

let result = calculate_amount_to_forward_per_htlc(&htlcs, total_amt_to_forward_msat);
let total_received_msat = o_0 + o_1 + o_2;

if total_received_msat < total_amt_to_forward_msat {
assert_eq!(result.len(), 0);
} else {
assert_ne!(result.len(), 0);
assert_eq!(result[0].0, htlcs[0].intercept_id);
assert_eq!(result[1].0, htlcs[1].intercept_id);
assert_eq!(result[2].0, htlcs[2].intercept_id);
assert!(result[0].1 <= o_0);
assert!(result[1].1 <= o_1);
assert!(result[2].1 <= o_2);

let result_sum = result.iter().map(|(_, f)| f).sum::<u64>();
assert!(result_sum >= total_amt_to_forward_msat);
let five_pct = result_sum as f32 * 0.1;
let fair_share_0 = ((o_0 as f32 / total_received_msat as f32) * result_sum as f32).max(o_0 as f32);
assert!(result[0].1 as f32 <= fair_share_0 + five_pct);
let fair_share_1 = ((o_1 as f32 / total_received_msat as f32) * result_sum as f32).max(o_1 as f32);
assert!(result[1].1 as f32 <= fair_share_1 + five_pct);
let fair_share_2 = ((o_2 as f32 / total_received_msat as f32) * result_sum as f32).max(o_2 as f32);
assert!(result[2].1 as f32 <= fair_share_2 + five_pct);
}

assert_eq!(result[2].0, htlcs[2].intercept_id);
assert_eq!(result[2].1, 2499);
}
}
}
26 changes: 26 additions & 0 deletions src/lsps2/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,29 @@ pub fn compute_opening_fee(
.and_then(|f| f.checked_div(1000000))
.map(|f| core::cmp::max(f, opening_fee_min_fee_msat))
}

#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;

const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000;

fn arb_opening_fee_params() -> impl Strategy<Value = (u64, u64, u64)> {
(0u64..MAX_VALUE_MSAT, 0u64..MAX_VALUE_MSAT, 0u64..MAX_VALUE_MSAT)
}

proptest! {
#[test]
fn test_compute_opening_fee((payment_size_msat, opening_fee_min_fee_msat, opening_fee_proportional) in arb_opening_fee_params()) {
if let Some(res) = compute_opening_fee(payment_size_msat, opening_fee_min_fee_msat, opening_fee_proportional) {
assert!(res >= opening_fee_min_fee_msat);
assert_eq!(res as f32, (payment_size_msat as f32 * opening_fee_proportional as f32));
} else {
// Check we actually overflowed.
let max_value = u64::MAX as u128;
assert!((payment_size_msat as u128 * opening_fee_proportional as u128) > max_value);
}
}
}
}