@@ -67,27 +67,40 @@ def __init__(self, df: pd.DataFrame) -> None:
67
67
self .df = df [["customer_id" , "segment_name" , "segment_id" ]].set_index ("customer_id" )
68
68
69
69
70
- class HMLSegmentation (BaseSegmentation ):
71
- """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend ."""
70
+ class ThresholdSegmentation (BaseSegmentation ):
71
+ """Segments customers based on user-defined thresholds and segments ."""
72
72
73
73
def __init__ (
74
74
self ,
75
75
df : pd .DataFrame ,
76
+ thresholds : list [float ],
77
+ segments : dict [any , str ],
76
78
value_col : str = "total_price" ,
79
+ agg_func : str = "sum" ,
80
+ zero_segment_name : str = "Zero" ,
81
+ zero_segment_id : str = "Z" ,
77
82
zero_value_customers : Literal ["separate_segment" , "exclude" , "include_with_light" ] = "separate_segment" ,
78
83
) -> None :
79
- """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend .
84
+ """Segments customers based on user-defined thresholds and segments .
80
85
81
86
Args:
82
87
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
83
- value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
88
+ thresholds (List[float]): The percentile thresholds for segmentation.
89
+ segments (Dict[Any, str]): A dictionary where keys are segment IDs and values are segment names, listed in order from the lowest to the highest threshold band.
90
+ value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
91
+ agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
92
+ zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
93
+ zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
84
94
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
85
95
customers with zero spend. Defaults to "separate_segment".
86
96
87
97
Raises:
88
98
ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
89
99
null values.
90
100
"""
101
+ if df .empty :
102
+ raise ValueError ("Input DataFrame is empty" )
103
+
91
104
required_cols = ["customer_id" , value_col ]
92
105
contract = CustomContract (
93
106
df ,
@@ -99,33 +112,79 @@ def __init__(
99
112
msg = f"The dataframe requires the columns { required_cols } and they must be non-null"
100
113
raise ValueError (msg )
101
114
115
# Validate the threshold/segment configuration before doing any aggregation.
# The error message reports customers, so count distinct customers rather than transaction rows.
n_customers = df["customer_id"].nunique()
if n_customers < len(thresholds):
    msg = f"There are {n_customers} customers, which is less than the number of segment thresholds."
    raise ValueError(msg)

# A set always equals itself, so duplicates are detected by comparing sizes instead.
if len(set(thresholds)) != len(thresholds):
    raise ValueError("The thresholds must be unique.")

# sorted() returns a new list, so the caller's list is never mutated by the padding below.
thresholds = sorted(thresholds)
# Pad the quantile edges so they always span 0..1; an empty list becomes [0, 1].
if not thresholds or thresholds[0] != 0:
    thresholds = [0, *thresholds]
if thresholds[-1] != 1:
    thresholds.append(1)

# N + 1 quantile edges delimit N bins, and every bin needs exactly one segment label.
if len(thresholds) - 1 != len(segments):
    raise ValueError("The number of thresholds must match the number of segments.")

# Group by customer_id and aggregate value_col with the requested aggregation function
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)

# Separate customers with zero spend
self.df = grouped_df
if zero_value_customers in ["separate_segment", "exclude"]:
    zero_idx = grouped_df[value_col] == 0
    zero_cust_df = grouped_df[zero_idx].copy()
    zero_cust_df["segment_name"] = zero_segment_name
    zero_cust_df["segment_id"] = zero_segment_id

    # Copy so the column assignments below write to an independent frame, not a slice of grouped_df.
    self.df = grouped_df[~zero_idx].copy()

# Bin the remaining customers by quantile; labels follow the insertion order of `segments`.
labels = list(segments.values())
self.df["segment_name"] = pd.qcut(
    self.df[value_col],
    q=thresholds,
    labels=labels,
)

# Invert the id -> name mapping to recover each customer's segment id from its segment name.
name_to_id = {name: seg_id for seg_id, name in segments.items()}
self.df["segment_id"] = self.df["segment_name"].map(name_to_id)

# Zero-spend customers were labelled separately above; add them back only when requested.
if zero_value_customers == "separate_segment":
    self.df = pd.concat([self.df, zero_cust_df])
123
157
124
- segment_code_map = {"Light" : "L" , "Medium" : "M" , "Heavy" : "H" , "Zero" : "Z" }
125
158
126
- hml_df ["segment_id" ] = hml_df ["segment_name" ].map (segment_code_map )
159
class HMLSegmentation(ThresholdSegmentation):
    """Threshold segmentation preset with Light, Medium and Heavy segments."""

    def __init__(
        self,
        df: pd.DataFrame,
        value_col: str = "total_price",
        agg_func: str = "sum",
        zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
    ) -> None:
        """Build the Light / Medium / Heavy segmentation by delegating to ThresholdSegmentation.

        The parent constructor receives the fixed thresholds 0.5, 0.8 and 1 together with the
        segment mapping L -> Light, M -> Medium, H -> Heavy; every other argument is passed through.

        Args:
            df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
            value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
            agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
            zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
                customers with zero spend. Defaults to "separate_segment".
        """
        # Insertion order matters: it lists the segments from the lowest band to the highest.
        hml_segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
        hml_thresholds = [0.500, 0.800, 1]

        super().__init__(
            df=df,
            thresholds=hml_thresholds,
            segments=hml_segments,
            value_col=value_col,
            agg_func=agg_func,
            zero_value_customers=zero_value_customers,
        )
129
188
130
189
131
190
class SegTransactionStats :
0 commit comments