From 2dea1a96c4eca23086fe17e7b08b369f774d257b Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Fri, 1 Dec 2023 11:31:24 +0100 Subject: [PATCH 1/3] Use `proptest` for `test_calculate_amount_to_forward` --- .github/workflows/build.yml | 2 +- Cargo.toml | 3 + .../lsps2/channel_manager.txt | 6 ++ src/lsps2/service.rs | 78 ++++++++++++------- 4 files changed, 61 insertions(+), 28 deletions(-) create mode 100644 proptest-regressions/lsps2/channel_manager.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f93a776..e304183 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,7 +28,7 @@ jobs: - name: Pin crates for MSRV if: matrix.msrv run: | - # No need to pin currently + cargo update -p proptest --precise "1.2.0" --verbose # proptest 1.3.0 requires rustc 1.64.0 - name: Cargo check run: cargo check --release - name: Check documentation diff --git a/Cargo.toml b/Cargo.toml index a71a03c..b909522 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,6 @@ core2 = { version = "0.3.0", optional = true, default-features = false } chrono = { version = "0.4", default-features = false, features = ["serde", "alloc"] } serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } serde_json = "1.0" + +[dev-dependencies] +proptest = "1.0.0" diff --git a/proptest-regressions/lsps2/channel_manager.txt b/proptest-regressions/lsps2/channel_manager.txt new file mode 100644 index 0000000..341ba3f --- /dev/null +++ b/proptest-regressions/lsps2/channel_manager.txt @@ -0,0 +1,6 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. diff --git a/src/lsps2/service.rs b/src/lsps2/service.rs index 706d8ab..b772b58 100644 --- a/src/lsps2/service.rs +++ b/src/lsps2/service.rs @@ -810,36 +810,60 @@ fn calculate_amount_to_forward_per_htlc( mod tests { use super::*; + use proptest::prelude::*; - #[test] - fn test_calculate_amount_to_forward() { - // TODO: Use proptest to generate random allocations - let htlcs = vec![ - InterceptedHTLC { - intercept_id: InterceptId([0; 32]), - expected_outbound_amount_msat: 1000, - }, - InterceptedHTLC { - intercept_id: InterceptId([1; 32]), - expected_outbound_amount_msat: 2000, - }, - InterceptedHTLC { - intercept_id: InterceptId([2; 32]), - expected_outbound_amount_msat: 3000, - }, - ]; - - let total_amt_to_forward_msat = 5000; + const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000; - let result = calculate_amount_to_forward_per_htlc(&htlcs, total_amt_to_forward_msat); - - assert_eq!(result[0].0, htlcs[0].intercept_id); - assert_eq!(result[0].1, 834); + fn arb_forward_amounts() -> impl Strategy { + (1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT) + .prop_map(|(a, b, c, d)| { + (a, b, c, core::cmp::min(d, a.saturating_add(b).saturating_add(c))) + }) + } - assert_eq!(result[1].0, htlcs[1].intercept_id); - assert_eq!(result[1].1, 1667); + proptest! { + #[test] + fn test_calculate_amount_to_forward((o_0, o_1, o_2, total_amt_to_forward_msat) in arb_forward_amounts()) { + let htlcs = vec![ + InterceptedHTLC { + intercept_id: InterceptId([0; 32]), + expected_outbound_amount_msat: o_0 + }, + InterceptedHTLC { + intercept_id: InterceptId([1; 32]), + expected_outbound_amount_msat: o_1 + }, + InterceptedHTLC { + intercept_id: InterceptId([2; 32]), + expected_outbound_amount_msat: o_2 + }, + ]; + + let result = calculate_amount_to_forward_per_htlc(&htlcs, total_amt_to_forward_msat); + let total_received_msat = o_0 + o_1 + o_2; + + if total_received_msat < total_amt_to_forward_msat { + assert_eq!(result.len(), 0); + } else { + assert_ne!(result.len(), 0); + assert_eq!(result[0].0, htlcs[0].intercept_id); + assert_eq!(result[1].0, htlcs[1].intercept_id); + assert_eq!(result[2].0, htlcs[2].intercept_id); + assert!(result[0].1 <= o_0); + assert!(result[1].1 <= o_1); + assert!(result[2].1 <= o_2); + + let result_sum = result.iter().map(|(_, f)| f).sum::(); + assert!(result_sum >= total_amt_to_forward_msat); + let five_pct = result_sum as f32 * 0.1; + let fair_share_0 = ((o_0 as f32 / total_received_msat as f32) * result_sum as f32).max(o_0 as f32); + assert!(result[0].1 as f32 <= fair_share_0 + five_pct); + let fair_share_1 = ((o_1 as f32 / total_received_msat as f32) * result_sum as f32).max(o_1 as f32); + assert!(result[1].1 as f32 <= fair_share_1 + five_pct); + let fair_share_2 = ((o_2 as f32 / total_received_msat as f32) * result_sum as f32).max(o_2 as f32); + assert!(result[2].1 as f32 <= fair_share_2 + five_pct); + } - assert_eq!(result[2].0, htlcs[2].intercept_id); - assert_eq!(result[2].1, 2499); + } } } From b94b2fec6b385d5bd1639e0fbf8d0b675df149bf Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Fri, 1 Dec 2023 12:39:36 +0100 Subject: [PATCH 2/3] Adapt `calculate_amount_to_forward_per_htlc` to avoid over/underflows --- src/lsps2/service.rs | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/lsps2/service.rs b/src/lsps2/service.rs index b772b58..1e2ba83 100644 --- a/src/lsps2/service.rs +++ b/src/lsps2/service.rs @@ -779,31 +779,38 @@ where fn calculate_amount_to_forward_per_htlc( htlcs: &[InterceptedHTLC], total_amt_to_forward_msat: u64, ) -> Vec<(InterceptId, u64)> { + // TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat let total_received_msat: u64 = htlcs.iter().map(|htlc| htlc.expected_outbound_amount_msat).sum(); - let mut fee_remaining_msat = total_received_msat - total_amt_to_forward_msat; - let total_fee_msat = fee_remaining_msat; + match total_received_msat.checked_sub(total_amt_to_forward_msat) { + Some(total_fee_msat) => { + let mut fee_remaining_msat = total_fee_msat; - let mut per_htlc_forwards = vec![]; + let mut per_htlc_forwards = vec![]; - for (index, htlc) in htlcs.iter().enumerate() { - let proportional_fee_amt_msat = - total_fee_msat * htlc.expected_outbound_amount_msat / total_received_msat; + for (index, htlc) in htlcs.iter().enumerate() { + let proportional_fee_amt_msat = + total_fee_msat * (htlc.expected_outbound_amount_msat / total_received_msat); - let mut actual_fee_amt_msat = core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat); - fee_remaining_msat -= actual_fee_amt_msat; + let mut actual_fee_amt_msat = + core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat); + fee_remaining_msat -= actual_fee_amt_msat; - if index == htlcs.len() - 1 { - actual_fee_amt_msat += fee_remaining_msat; - } + if index == htlcs.len() - 1 { + actual_fee_amt_msat += fee_remaining_msat; + } - let amount_to_forward_msat = htlc.expected_outbound_amount_msat - actual_fee_amt_msat; + let amount_to_forward_msat = + htlc.expected_outbound_amount_msat.saturating_sub(actual_fee_amt_msat); - per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat)) - } + per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat)) + } - per_htlc_forwards + per_htlc_forwards + } + None => Vec::new(), + } } #[cfg(test)] From 77d9078e84d8b320338e278e738970bffdd8d766 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Tue, 5 Dec 2023 14:56:20 +0100 Subject: [PATCH 3/3] Add test for `compute_opening_fee` --- src/lsps2/utils.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/lsps2/utils.rs b/src/lsps2/utils.rs index 6721a5b..9a63d2c 100644 --- a/src/lsps2/utils.rs +++ b/src/lsps2/utils.rs @@ -54,3 +54,29 @@ pub fn compute_opening_fee( .and_then(|f| f.checked_div(1000000)) .map(|f| core::cmp::max(f, opening_fee_min_fee_msat)) } + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000; + + fn arb_opening_fee_params() -> impl Strategy { + (0u64..MAX_VALUE_MSAT, 0u64..MAX_VALUE_MSAT, 0u64..MAX_VALUE_MSAT) + } + + proptest! { + #[test] + fn test_compute_opening_fee((payment_size_msat, opening_fee_min_fee_msat, opening_fee_proportional) in arb_opening_fee_params()) { + if let Some(res) = compute_opening_fee(payment_size_msat, opening_fee_min_fee_msat, opening_fee_proportional) { + assert!(res >= opening_fee_min_fee_msat); + assert_eq!(res as f32, (payment_size_msat as f32 * opening_fee_proportional as f32)); + } else { + // Check we actually overflowed. + let max_value = u64::MAX as u128; + assert!((payment_size_msat as u128 * opening_fee_proportional as u128) > max_value); + } + } + } +}