77from pandas .testing import assert_frame_equal
88
99from pyretailscience .metrics .distribution .acv import Acv
10+ from pyretailscience .options import ColumnHelper
11+
12+ cols = ColumnHelper ()
1013
1114
1215class TestAcv :
@@ -16,10 +19,10 @@ def test_acv_total_no_grouping(self):
1619 """Test total ACV across all transactions without grouping."""
1720 df = pd .DataFrame (
1821 {
19- " customer_id" : [1 , 2 , 3 , 1 , 2 ],
20- " store_id" : [101 , 101 , 102 , 102 , 103 ],
21- " product_id" : [10 , 20 , 30 , 40 , 50 ],
22- " unit_spend" : [500_000.0 , 750_000.0 , 300_000.0 , 600_000.0 , 350_000.0 ],
22+ cols . customer_id : [1 , 2 , 3 , 1 , 2 ],
23+ cols . store_id : [101 , 101 , 102 , 102 , 103 ],
24+ cols . product_id : [10 , 20 , 30 , 40 , 50 ],
25+ cols . unit_spend : [500_000.0 , 750_000.0 , 300_000.0 , 600_000.0 , 350_000.0 ],
2326 }
2427 )
2528 result = Acv (df ).df
@@ -31,15 +34,15 @@ def test_acv_grouped_by_store(self, input_type):
3134 """Test ACV grouped by store returns correct per-store values for both input types."""
3235 pdf = pd .DataFrame (
3336 {
34- " store_id" : [101 , 101 , 102 , 102 , 103 ],
35- " unit_spend" : [400_000.0 , 600_000.0 , 300_000.0 , 200_000.0 , 500_000.0 ],
37+ cols . store_id : [101 , 101 , 102 , 102 , 103 ],
38+ cols . unit_spend : [400_000.0 , 600_000.0 , 300_000.0 , 200_000.0 , 500_000.0 ],
3639 }
3740 )
3841 df = ibis .memtable (pdf ) if input_type == "ibis" else pdf
39- result = Acv (df , group_col = " store_id" ).df .sort_values (" store_id" ).reset_index (drop = True )
42+ result = Acv (df , group_col = cols . store_id ).df .sort_values (cols . store_id ).reset_index (drop = True )
4043 expected = pd .DataFrame (
4144 {
42- " store_id" : [101 , 102 , 103 ],
45+ cols . store_id : [101 , 102 , 103 ],
4346 "acv" : [1.0 , 0.5 , 0.5 ],
4447 }
4548 )
@@ -49,15 +52,15 @@ def test_acv_group_col_list(self):
4952 """Test ACV grouped by multiple columns."""
5053 df = pd .DataFrame (
5154 {
52- " store_id" : [101 , 101 , 102 ],
55+ cols . store_id : [101 , 101 , 102 ],
5356 "region" : ["North" , "North" , "South" ],
54- " unit_spend" : [1_000_000.0 , 500_000.0 , 2_000_000.0 ],
57+ cols . unit_spend : [1_000_000.0 , 500_000.0 , 2_000_000.0 ],
5558 }
5659 )
57- result = Acv (df , group_col = [" store_id" , "region" ]).df .sort_values (" store_id" ).reset_index (drop = True )
60+ result = Acv (df , group_col = [cols . store_id , "region" ]).df .sort_values (cols . store_id ).reset_index (drop = True )
5861 expected = pd .DataFrame (
5962 {
60- " store_id" : [101 , 102 ],
63+ cols . store_id : [101 , 102 ],
6164 "region" : ["North" , "South" ],
6265 "acv" : [1.5 , 2.0 ],
6366 }
@@ -68,37 +71,37 @@ def test_acv_with_nan_values(self):
6871 """Test that NaN values are excluded from the ACV sum."""
6972 df = pd .DataFrame (
7073 {
71- " store_id" : [101 , 101 , 102 ],
72- " unit_spend" : [1_000_000.0 , np .nan , 500_000.0 ],
74+ cols . store_id : [101 , 101 , 102 ],
75+ cols . unit_spend : [1_000_000.0 , np .nan , 500_000.0 ],
7376 }
7477 )
75- result = Acv (df , group_col = " store_id" ).df .sort_values (" store_id" ).reset_index (drop = True )
78+ result = Acv (df , group_col = cols . store_id ).df .sort_values (cols . store_id ).reset_index (drop = True )
7679 expected = pd .DataFrame (
7780 {
78- " store_id" : [101 , 102 ],
81+ cols . store_id : [101 , 102 ],
7982 "acv" : [1.0 , 0.5 ],
8083 }
8184 )
8285 assert_frame_equal (result , expected )
8386
8487 def test_acv_missing_column_raises (self ):
8588 """Test that missing unit_spend column raises ValueError."""
86- df = pd .DataFrame ({" customer_id" : [1 , 2 ], " store_id" : [101 , 102 ]})
89+ df = pd .DataFrame ({cols . customer_id : [1 , 2 ], cols . store_id : [101 , 102 ]})
8790 with pytest .raises (ValueError , match = "missing" ):
8891 Acv (df )
8992
90- def test_acv_missing_group_col_column_raises (self ):
93+ def test_acv_missing_group_col_raises (self ):
9194 """Test that missing group_col column raises ValueError."""
92- df = pd .DataFrame ({" unit_spend" : [100.0 , 200.0 ]})
95+ df = pd .DataFrame ({cols . unit_spend : [100.0 , 200.0 ]})
9396 with pytest .raises (ValueError , match = "missing" ):
94- Acv (df , group_col = " store_id" )
97+ Acv (df , group_col = cols . store_id )
9598
9699 def test_acv_custom_scale_factor (self ):
97100 """Test ACV with a custom scale factor."""
98101 df = pd .DataFrame (
99102 {
100- " store_id" : [101 , 102 ],
101- " unit_spend" : [5_000.0 , 10_000.0 ],
103+ cols . store_id : [101 , 102 ],
104+ cols . unit_spend : [5_000.0 , 10_000.0 ],
102105 }
103106 )
104107 result = Acv (df , acv_scale_factor = 1_000 ).df
@@ -108,11 +111,11 @@ def test_acv_custom_scale_factor(self):
108111 @pytest .mark .parametrize ("scale_factor" , [0 , - 1_000 ])
109112 def test_acv_non_positive_scale_factor_raises (self , scale_factor ):
110113 """Test that zero or negative acv_scale_factor raises ValueError."""
111- df = pd .DataFrame ({" unit_spend" : [500_000.0 , 1_000_000.0 ]})
114+ df = pd .DataFrame ({cols . unit_spend : [500_000.0 , 1_000_000.0 ]})
112115 with pytest .raises (ValueError , match = "acv_scale_factor must be positive" ):
113116 Acv (df , acv_scale_factor = scale_factor )
114117
115118 def test_acv_invalid_type_raises (self ):
116119 """Test that passing a non-DataFrame/Table raises TypeError."""
117120 with pytest .raises (TypeError , match = "pandas DataFrame or an Ibis Table" ):
118- Acv ({" unit_spend" : [100.0 ]})
121+ Acv ({cols . unit_spend : [100.0 ]})
0 commit comments