@@ -15,9 +15,20 @@ def sample_data():
15
15
return pd .DataFrame (
16
16
{
17
17
cols .customer_id : [1 , 2 , 3 , 4 , 5 , 5 , 6 , 7 , 8 , 8 , 9 , 10 ],
18
- "group_1_idx" : [True , False , False , False , False , True , True , False , False , True , False , True ],
19
- "group_2_idx" : [False , True , False , False , True , False , False , True , False , False , True , False ],
20
- "group_3_idx" : [False , False , True , False , False , False , False , False , True , False , False , False ],
18
+ "category_1_name" : [
19
+ "Jeans" ,
20
+ "Shoes" ,
21
+ "Dresses" ,
22
+ "Hats" ,
23
+ "Shoes" ,
24
+ "Jeans" ,
25
+ "Jeans" ,
26
+ "Shoes" ,
27
+ "Dresses" ,
28
+ "Jeans" ,
29
+ "Shoes" ,
30
+ "Jeans" ,
31
+ ],
21
32
cols .unit_spend : [10 , 20 , 30 , 40 , 20 , 50 , 10 , 20 , 30 , 15 , 40 , 50 ],
22
33
},
23
34
)
@@ -27,14 +38,16 @@ def test_calc_cross_shop_two_groups(sample_data):
27
38
"""Test the _calc_cross_shop method with two groups."""
28
39
cross_shop_df = CrossShop ._calc_cross_shop (
29
40
sample_data ,
30
- group_1_idx = sample_data ["group_1_idx" ],
31
- group_2_idx = sample_data ["group_2_idx" ],
41
+ group_1_col = "category_1_name" ,
42
+ group_1_val = "Jeans" ,
43
+ group_2_col = "category_1_name" ,
44
+ group_2_val = "Shoes" ,
32
45
)
33
46
ret_df = pd .DataFrame (
34
47
{
35
48
cols .customer_id : [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ],
36
- "group_1" : [1 , 0 , 0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ],
37
- "group_2" : [0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ],
49
+ "group_1" : pd . Series ( [1 , 0 , 0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ], dtype = "int32" ) ,
50
+ "group_2" : pd . Series ( [0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ], dtype = "int32" ) ,
38
51
"groups" : [(1 , 0 ), (0 , 1 ), (0 , 0 ), (0 , 0 ), (1 , 1 ), (1 , 0 ), (0 , 1 ), (1 , 0 ), (0 , 1 ), (1 , 0 )],
39
52
cols .unit_spend : [10 , 20 , 30 , 40 , 70 , 10 , 20 , 45 , 40 , 50 ],
40
53
},
@@ -47,16 +60,19 @@ def test_calc_cross_shop_three_groups(sample_data):
47
60
"""Test the _calc_cross_shop method with three groups."""
48
61
cross_shop_df = CrossShop ._calc_cross_shop (
49
62
sample_data ,
50
- group_1_idx = sample_data ["group_1_idx" ],
51
- group_2_idx = sample_data ["group_2_idx" ],
52
- group_3_idx = sample_data ["group_3_idx" ],
63
+ group_1_col = "category_1_name" ,
64
+ group_1_val = "Jeans" ,
65
+ group_2_col = "category_1_name" ,
66
+ group_2_val = "Shoes" ,
67
+ group_3_col = "category_1_name" ,
68
+ group_3_val = "Dresses" ,
53
69
)
54
70
ret_df = pd .DataFrame (
55
71
{
56
72
cols .customer_id : [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ],
57
- "group_1" : [1 , 0 , 0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ],
58
- "group_2" : [0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ],
59
- "group_3" : [0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 , 0 ],
73
+ "group_1" : pd . Series ( [1 , 0 , 0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ], dtype = "int32" ) ,
74
+ "group_2" : pd . Series ( [0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ], dtype = "int32" ) ,
75
+ "group_3" : pd . Series ( [0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 , 0 ], dtype = "int32" ) ,
60
76
"groups" : [
61
77
(1 , 0 , 0 ),
62
78
(0 , 1 , 0 ),
@@ -73,39 +89,19 @@ def test_calc_cross_shop_three_groups(sample_data):
73
89
},
74
90
).set_index (cols .customer_id )
75
91
76
- assert cross_shop_df .equals (ret_df )
77
-
78
-
79
- def test_calc_cross_shop_two_groups_overlap_error (sample_data ):
80
- """Test the _calc_cross_shop method with two groups and overlapping group indices."""
81
- with pytest .raises (ValueError ):
82
- CrossShop ._calc_cross_shop (
83
- sample_data ,
84
- # Pass the same group index for both groups
85
- group_1_idx = sample_data ["group_1_idx" ],
86
- group_2_idx = sample_data ["group_1_idx" ],
87
- )
88
-
89
-
90
- def test_calc_cross_shop_three_groups_overlap_error (sample_data ):
91
- """Test the _calc_cross_shop method with three groups and overlapping group indices."""
92
- with pytest .raises (ValueError ):
93
- CrossShop ._calc_cross_shop (
94
- sample_data ,
95
- # Pass the same group index for groups 1 and 3
96
- group_1_idx = sample_data ["group_1_idx" ],
97
- group_2_idx = sample_data ["group_2_idx" ],
98
- group_3_idx = sample_data ["group_1_idx" ],
99
- )
92
+ pd .testing .assert_frame_equal (cross_shop_df , ret_df , check_dtype = False )
100
93
101
94
102
95
def test_calc_cross_shop_three_groups_customer_id_nunique (sample_data ):
103
96
"""Test the _calc_cross_shop method with three groups and customer_id as the value column."""
104
97
cross_shop_df = CrossShop ._calc_cross_shop (
105
98
sample_data ,
106
- group_1_idx = sample_data ["group_1_idx" ],
107
- group_2_idx = sample_data ["group_2_idx" ],
108
- group_3_idx = sample_data ["group_3_idx" ],
99
+ group_1_col = "category_1_name" ,
100
+ group_1_val = "Jeans" ,
101
+ group_2_col = "category_1_name" ,
102
+ group_2_val = "Shoes" ,
103
+ group_3_col = "category_1_name" ,
104
+ group_3_val = "Dresses" ,
109
105
value_col = cols .customer_id ,
110
106
agg_func = "nunique" ,
111
107
)
@@ -131,17 +127,20 @@ def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
131
127
index = [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ],
132
128
)
133
129
ret_df .index .name = cols .customer_id
134
-
130
+ ret_df = ret_df . astype ({ "group_1" : "int32" , "group_2" : "int32" , "group_3" : "int32" })
135
131
assert cross_shop_df .equals (ret_df )
136
132
137
133
138
134
def test_calc_cross_shop_table (sample_data ):
139
135
"""Test the _calc_cross_shop_table method."""
140
136
cross_shop_df = CrossShop ._calc_cross_shop (
141
137
sample_data ,
142
- group_1_idx = sample_data ["group_1_idx" ],
143
- group_2_idx = sample_data ["group_2_idx" ],
144
- group_3_idx = sample_data ["group_3_idx" ],
138
+ group_1_col = "category_1_name" ,
139
+ group_1_val = "Jeans" ,
140
+ group_2_col = "category_1_name" ,
141
+ group_2_val = "Shoes" ,
142
+ group_3_col = "category_1_name" ,
143
+ group_3_val = "Dresses" ,
145
144
value_col = cols .unit_spend ,
146
145
)
147
146
cross_shop_table = CrossShop ._calc_cross_shop_table (
@@ -174,9 +173,12 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
174
173
"""Test the _calc_cross_shop_table method with customer_id as the value column."""
175
174
cross_shop_df = CrossShop ._calc_cross_shop (
176
175
sample_data ,
177
- group_1_idx = sample_data ["group_1_idx" ],
178
- group_2_idx = sample_data ["group_2_idx" ],
179
- group_3_idx = sample_data ["group_3_idx" ],
176
+ group_1_col = "category_1_name" ,
177
+ group_1_val = "Jeans" ,
178
+ group_2_col = "category_1_name" ,
179
+ group_2_val = "Shoes" ,
180
+ group_3_col = "category_1_name" ,
181
+ group_3_val = "Dresses" ,
180
182
value_col = cols .customer_id ,
181
183
agg_func = "nunique" ,
182
184
)
@@ -195,19 +197,24 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
195
197
assert cross_shop_table .equals (ret_df )
196
198
197
199
198
- def test_calc_cross_shop_all_groups_false (sample_data ):
199
- """Test the _calc_cross_shop method with all group indices set to False ."""
200
- with pytest .raises (ValueError ):
200
+ def test_calc_cross_shop_invalid_group_3 (sample_data ):
201
+ """Test that _calc_cross_shop raises ValueError if only one of group_3_col or group_3_val is provided ."""
202
+ with pytest .raises (ValueError , match = "If group_3_col or group_3_val is populated, then the other must be as well" ):
201
203
CrossShop ._calc_cross_shop (
202
204
sample_data ,
203
- group_1_idx = [False ] * len (sample_data ),
204
- group_2_idx = [False ] * len (sample_data ),
205
+ group_1_col = "category_1_name" ,
206
+ group_1_val = "Jeans" ,
207
+ group_2_col = "category_1_name" ,
208
+ group_2_val = "Shoes" ,
209
+ group_3_col = "category_1_name" ,
205
210
)
206
211
207
- with pytest .raises (ValueError ):
212
+ with pytest .raises (ValueError , match = "If group_3_col or group_3_val is populated, then the other must be as well" ):
208
213
CrossShop ._calc_cross_shop (
209
214
sample_data ,
210
- group_1_idx = [False ] * len (sample_data ),
211
- group_2_idx = [False ] * len (sample_data ),
212
- group_3_idx = [False ] * len (sample_data ),
215
+ group_1_col = "category_1_name" ,
216
+ group_1_val = "Jeans" ,
217
+ group_2_col = "category_1_name" ,
218
+ group_2_val = "Shoes" ,
219
+ group_3_val = "T-Shirts" ,
213
220
)
0 commit comments