     build_expected_unique_columns,
     build_non_null_columns,
 )
+from pyretailscience.options import get_option
 from pyretailscience.style.tailwind import COLORS


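The new import resolves column names through pyretailscience's options registry instead of hard-coded literals. A minimal sketch of the lookup as it is used throughout this diff — the concrete default value mentioned in the comment, and the existence of a pandas-style set_option counterpart, are assumptions rather than anything this change confirms:

```python
from pyretailscience.options import get_option

# Resolve the configured customer-id column name once and reuse it.
# The default is presumably the previously hard-coded "customer_id",
# and a set_option() counterpart is assumed to exist for overriding it.
customer_id_col = get_option("column.customer_id")
print(customer_id_col)
```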
@@ -31,7 +32,12 @@ def add_segment(self, df: pd.DataFrame) -> pd.DataFrame:
             ValueError: If the number of rows before and after the merge do not match.
         """
         rows_before = len(df)
-        df = df.merge(self.df[["segment_name", "segment_id"]], how="left", left_on="customer_id", right_index=True)
+        df = df.merge(
+            self.df[["segment_name", "segment_id"]],
+            how="left",
+            left_on=get_option("column.customer_id"),
+            right_index=True,
+        )
         rows_after = len(df)
         if rows_before != rows_after:
             raise ValueError("The number of rows before and after the merge do not match. This should not happen.")
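The reworked merge joins the per-customer segment labels onto the transaction frame via the configured customer-id column against the segment frame's index, then checks that the left join added no rows. A self-contained pandas sketch of the same pattern, using illustrative column names rather than the library's configured ones:

```python
import pandas as pd

# Toy stand-ins for the transaction data and the per-customer segment lookup
# (the lookup is indexed by customer id, as BaseSegmentation stores it).
transactions = pd.DataFrame({"customer_id": [1, 1, 2], "spend": [10.0, 5.0, 7.5]})
segment_lookup = pd.DataFrame(
    {"segment_name": ["Heavy", "Light"], "segment_id": ["H", "L"]},
    index=pd.Index([1, 2], name="customer_id"),
)

rows_before = len(transactions)
merged = transactions.merge(
    segment_lookup[["segment_name", "segment_id"]],
    how="left",
    left_on="customer_id",  # the patched code reads this name via get_option("column.customer_id")
    right_index=True,
)
# A unique right index guarantees the left join cannot duplicate rows.
assert len(merged) == rows_before
```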
@@ -51,7 +57,7 @@ def __init__(self, df: pd.DataFrame) -> None:
         Raises:
             ValueError: If the dataframe does not have the columns customer_id, segment_name and segment_id.
         """
-        required_cols = "customer_id", "segment_name", "segment_id"
+        required_cols = get_option("column.customer_id"), "segment_name", "segment_id"
         contract = CustomContract(
             df,
             basic_expectations=build_expected_columns(columns=required_cols),
@@ -63,7 +69,9 @@ def __init__(self, df: pd.DataFrame) -> None:
             msg = f"The dataframe requires the columns {required_cols} and they must be non-null and unique."
             raise ValueError(msg)

-        self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
+        self.df = df[[get_option("column.customer_id"), "segment_name", "segment_id"]].set_index(
+            get_option("column.customer_id"),
+        )


 class ThresholdSegmentation(BaseSegmentation):
@@ -74,7 +82,7 @@ def __init__(
         df: pd.DataFrame,
         thresholds: list[float],
         segments: dict[any, str],
-        value_col: str = "total_price",
+        value_col: str | None = None,
         agg_func: str = "sum",
         zero_segment_name: str = "Zero",
         zero_segment_id: str = "Z",
@@ -86,7 +94,7 @@ def __init__(
             df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
             thresholds (List[float]): The percentile thresholds for segmentation.
             segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
-            value_col (str): The column to use for the segmentation.
+            value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend").
             agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
             zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
             zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
@@ -100,7 +108,9 @@ def __init__(
         if df.empty:
             raise ValueError("Input DataFrame is empty")

-        required_cols = ["customer_id", value_col]
+        value_col = get_option("column.unit_spend") if value_col is None else value_col
+
+        required_cols = [get_option("column.customer_id"), value_col]
         contract = CustomContract(
             df,
             basic_expectations=build_expected_columns(columns=required_cols),
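Switching the default from a literal to None and resolving it inside __init__ means the option is read at construction time rather than at import time. A hypothetical helper illustrating the sentinel pattern (the function name is illustrative, not part of the library):

```python
from pyretailscience.options import get_option


def resolve_value_col(value_col: str | None = None) -> str:
    # Using None as a sentinel and resolving here means a later option change
    # is still honoured; putting get_option(...) directly in the signature
    # would freeze the value at import time.
    return get_option("column.unit_spend") if value_col is None else value_col
```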
@@ -128,7 +138,7 @@ def __init__(
             raise ValueError("The number of thresholds must match the number of segments.")

         # Group by customer_id and calculate total_spend
-        grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)
+        grouped_df = df.groupby(get_option("column.customer_id"))[value_col].agg(agg_func).to_frame(value_col)

         # Separate customers with zero spend
         self.df = grouped_df
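The groupby now keys on the configured customer-id column; the rest of the expression is unchanged. A plain-pandas sketch of what it produces, with toy column names:

```python
import pandas as pd

df = pd.DataFrame({"customer_id": [1, 1, 2], "spend": [10.0, 5.0, 7.5]})

# Total spend per customer; .to_frame(value_col) keeps the aggregate as a
# named column so downstream code can keep addressing it by value_col.
grouped_df = df.groupby("customer_id")["spend"].agg("sum").to_frame("spend")
print(grouped_df)  # index: customer_id, one "spend" column
```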
@@ -138,7 +148,7 @@ def __init__(
         zero_cust_df["segment_name"] = zero_segment_name
         zero_cust_df["segment_id"] = zero_segment_id

-        self.df = grouped_df[~zero_idx]
+        self.df = grouped_df[~zero_idx].copy()

         # Create a new column 'segment' based on the total_spend
         labels = list(segments.values())
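The added .copy() matters because segment columns are assigned to this frame further down: a boolean slice of grouped_df can be a view, and assigning to a view triggers pandas' SettingWithCopyWarning. A minimal reproduction of the pattern:

```python
import pandas as pd

grouped_df = pd.DataFrame({"spend": [0.0, 12.0, 30.0]}, index=[1, 2, 3])
zero_idx = grouped_df["spend"] == 0

# Without .copy() the slice may be a view of grouped_df, and the assignments
# below can raise SettingWithCopyWarning; copying makes the frame independent.
non_zero_df = grouped_df[~zero_idx].copy()
non_zero_df["segment_name"] = "Heavy"
non_zero_df["segment_id"] = "H"
```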
@@ -161,20 +171,25 @@ class HMLSegmentation(ThresholdSegmentation):
     def __init__(
         self,
         df: pd.DataFrame,
-        value_col: str = "total_price",
+        value_col: str | None = None,
         agg_func: str = "sum",
         zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
     ) -> None:
         """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.

+        HMLSegmentation is a subclass of ThresholdSegmentation and is based around an industry standard definition.
+        The thresholds for Heavy (top 20%), Medium (next 30%) and Light (bottom 50%) are chosen based on the Pareto
+        distribution, commonly known as the 80/20 rule. It is typically used in retail to segment customers based on
+        their spend, transaction volume or quantities purchased.
+
         Args:
             df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
-            value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
+            value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend").
             agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
             zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
                 customers with zero spend. Defaults to "separate_segment".
         """
-        thresholds = [0.500, 0.800, 1]
+        thresholds = [0, 0.500, 0.800, 1]
         segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
         super().__init__(
             df=df,
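A usage sketch of the class after this change, assuming HMLSegmentation is importable from pyretailscience's segmentation module and that seg.df carries the assigned segment_name / segment_id per customer; the spend column is passed explicitly so the sketch does not depend on the option's default value:

```python
import pandas as pd

from pyretailscience.options import get_option
from pyretailscience.segmentation import HMLSegmentation  # assumed import path

customer_col = get_option("column.customer_id")

# Six customers with distinct totals plus one zero-spend customer.
df = pd.DataFrame(
    {
        customer_col: [1, 2, 3, 4, 5, 6, 7],
        "spend": [250.0, 180.0, 90.0, 60.0, 20.0, 5.0, 0.0],
    }
)

seg = HMLSegmentation(df, value_col="spend", zero_value_customers="separate_segment")
print(seg.df)  # per-customer totals with their Heavy/Medium/Light/Zero segment labels
```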
@@ -202,9 +217,14 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
             TransactionLevelContract.

         """
-        required_cols = ["customer_id", "total_price", "transaction_id", segment_col]
-        if "quantity" in df.columns:
-            required_cols.append("quantity")
+        required_cols = [
+            get_option("column.customer_id"),
+            get_option("column.unit_spend"),
+            get_option("column.transaction_id"),
+            segment_col,
+        ]
+        if get_option("column.unit_quantity") in df.columns:
+            required_cols.append(get_option("column.unit_quantity"))
         contract = CustomContract(
             df,
             basic_expectations=build_expected_columns(columns=required_cols),
@@ -222,18 +242,18 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
     @staticmethod
     def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
         aggs = {
-            "revenue": ("total_price", "sum"),
-            "transactions": ("transaction_id", "nunique"),
-            "customers": ("customer_id", "nunique"),
+            get_option("column.agg.unit_spend"): (get_option("column.unit_spend"), "sum"),
+            get_option("column.agg.transaction_id"): (get_option("column.transaction_id"), "nunique"),
+            get_option("column.agg.customer_id"): (get_option("column.customer_id"), "nunique"),
         }
         total_aggs = {
-            "revenue": [df["total_price"].sum()],
-            "transactions": [df["transaction_id"].nunique()],
-            "customers": [df["customer_id"].nunique()],
+            get_option("column.agg.unit_spend"): [df[get_option("column.unit_spend")].sum()],
+            get_option("column.agg.transaction_id"): [df[get_option("column.transaction_id")].nunique()],
+            get_option("column.agg.customer_id"): [df[get_option("column.customer_id")].nunique()],
         }
-        if "quantity" in df.columns:
-            aggs["total_quantity"] = ("quantity", "sum")
-            total_aggs["total_quantity"] = [df["quantity"].sum()]
+        if get_option("column.unit_quantity") in df.columns:
+            aggs[get_option("column.agg.unit_quantity")] = (get_option("column.unit_quantity"), "sum")
+            total_aggs[get_option("column.agg.unit_quantity")] = [df[get_option("column.unit_quantity")].sum()]

         stats_df = pd.concat(
             [
@@ -242,9 +262,13 @@ def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
             ],
         )

-        if "quantity" in df.columns:
-            stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
-            stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
+        if get_option("column.unit_quantity") in df.columns:
+            stats_df[get_option("column.calc.price_per_unit")] = (
+                stats_df[get_option("column.agg.unit_spend")] / stats_df[get_option("column.agg.unit_quantity")]
+            )
+            stats_df[get_option("column.calc.units_per_transaction")] = (
+                stats_df[get_option("column.agg.unit_quantity")] / stats_df[get_option("column.agg.transaction_id")]
+            )

         return stats_df
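For reference, the surrounding method builds a per-segment frame plus a totals row and then derives the two ratio columns; the exact rows fed into pd.concat sit outside these hunks. A self-contained pandas sketch of that pattern, using the old literal column names for readability:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "segment_id": ["H", "H", "L"],
        "total_price": [100.0, 60.0, 10.0],
        "transaction_id": [101, 102, 103],
        "customer_id": [1, 1, 2],
        "quantity": [5, 3, 1],
    }
)

aggs = {
    "revenue": ("total_price", "sum"),
    "transactions": ("transaction_id", "nunique"),
    "customers": ("customer_id", "nunique"),
    "total_quantity": ("quantity", "sum"),
}
total_aggs = {
    "revenue": [df["total_price"].sum()],
    "transactions": [df["transaction_id"].nunique()],
    "customers": [df["customer_id"].nunique()],
    "total_quantity": [df["quantity"].sum()],
}

# Per-segment rows plus a "total" row, mirroring the shape _calc_seg_stats returns.
stats_df = pd.concat(
    [
        df.groupby("segment_id").agg(**aggs),
        pd.DataFrame(total_aggs, index=["total"]),
    ],
)
stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
stats_df["units_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
print(stats_df)
```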