@@ -67,27 +67,40 @@ def __init__(self, df: pd.DataFrame) -> None:
6767 self .df = df [["customer_id" , "segment_name" , "segment_id" ]].set_index ("customer_id" )
6868
6969
70- class HMLSegmentation (BaseSegmentation ):
71- """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend ."""
70+ class ThresholdSegmentation (BaseSegmentation ):
71+ """Segments customers based on user-defined thresholds and segments ."""
7272
7373 def __init__ (
7474 self ,
7575 df : pd .DataFrame ,
76+ thresholds : list [float ],
77+ segments : dict [any , str ],
7678 value_col : str = "total_price" ,
79+ agg_func : str = "sum" ,
80+ zero_segment_name : str = "Zero" ,
81+ zero_segment_id : str = "Z" ,
7782 zero_value_customers : Literal ["separate_segment" , "exclude" , "include_with_light" ] = "separate_segment" ,
7883 ) -> None :
79- """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend .
84+ """Segments customers based on user-defined thresholds and segments .
8085
8186 Args:
8287 df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
83- value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
88+ thresholds (List[float]): The percentile thresholds for segmentation.
89+ segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
90+ value_col (str): The column to use for the segmentation.
91+ agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
92+ zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
93+ zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
8494 zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
8595 customers with zero spend. Defaults to "separate_segment".
8696
8797 Raises:
8898 ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
8999 null values.
90100 """
101+ if df .empty :
102+ raise ValueError ("Input DataFrame is empty" )
103+
91104 required_cols = ["customer_id" , value_col ]
92105 contract = CustomContract (
93106 df ,
@@ -99,33 +112,79 @@ def __init__(
99112 msg = f"The dataframe requires the columns { required_cols } and they must be non-null"
100113 raise ValueError (msg )
101114
115+ if len (df ) < len (thresholds ):
116+ msg = f"There are { len (df )} customers, which is less than the number of segment thresholds."
117+ raise ValueError (msg )
118+
119+ if set (thresholds ) != set (thresholds ):
120+ raise ValueError ("The thresholds must be unique." )
121+
122+ thresholds = sorted (thresholds )
123+ if thresholds [0 ] != 0 :
124+ thresholds = [0 , * thresholds ]
125+ if thresholds [- 1 ] != 1 :
126+ thresholds .append (1 )
127+
128+ if len (thresholds ) - 1 != len (segments ):
129+ raise ValueError ("The number of thresholds must match the number of segments." )
130+
102131 # Group by customer_id and calculate total_spend
103- grouped_df = df .groupby ("customer_id" )[value_col ].sum ( ).to_frame (value_col )
132+ grouped_df = df .groupby ("customer_id" )[value_col ].agg ( agg_func ).to_frame (value_col )
104133
105134 # Separate customers with zero spend
106- hml_df = grouped_df
135+ self . df = grouped_df
107136 if zero_value_customers in ["separate_segment" , "exclude" ]:
108137 zero_idx = grouped_df [value_col ] == 0
109- zero_cust_df = grouped_df [zero_idx ]
110- zero_cust_df ["segment_name" ] = "Zero"
138+ zero_cust_df = grouped_df [zero_idx ].copy ()
139+ zero_cust_df ["segment_name" ] = zero_segment_name
140+ zero_cust_df ["segment_id" ] = zero_segment_id
111141
112- hml_df = grouped_df [~ zero_idx ]
142+ self . df = grouped_df [~ zero_idx ]
113143
114144 # Create a new column 'segment' based on the total_spend
115- hml_df ["segment_name" ] = pd .qcut (
116- hml_df [value_col ],
117- q = [0 , 0.500 , 0.800 , 1 ],
118- labels = ["Light" , "Medium" , "Heavy" ],
145+ labels = list (segments .values ())
146+
147+ self .df ["segment_name" ] = pd .qcut (
148+ self .df [value_col ],
149+ q = thresholds ,
150+ labels = labels ,
119151 )
120152
153+ self .df ["segment_id" ] = self .df ["segment_name" ].map ({v : k for k , v in segments .items ()})
154+
121155 if zero_value_customers == "separate_segment" :
122- hml_df = pd .concat ([hml_df , zero_cust_df ])
156+ self . df = pd .concat ([self . df , zero_cust_df ])
123157
124- segment_code_map = {"Light" : "L" , "Medium" : "M" , "Heavy" : "H" , "Zero" : "Z" }
125158
126- hml_df ["segment_id" ] = hml_df ["segment_name" ].map (segment_code_map )
159+ class HMLSegmentation (ThresholdSegmentation ):
160+ """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
127161
128- self .df = hml_df
162+ def __init__ (
163+ self ,
164+ df : pd .DataFrame ,
165+ value_col : str = "total_price" ,
166+ agg_func : str = "sum" ,
167+ zero_value_customers : Literal ["separate_segment" , "exclude" , "include_with_light" ] = "separate_segment" ,
168+ ) -> None :
169+ """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
170+
171+ Args:
172+ df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
173+ value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
174+ agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
175+ zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
176+ customers with zero spend. Defaults to "separate_segment".
177+ """
178+ thresholds = [0.500 , 0.800 , 1 ]
179+ segments = {"L" : "Light" , "M" : "Medium" , "H" : "Heavy" }
180+ super ().__init__ (
181+ df = df ,
182+ value_col = value_col ,
183+ agg_func = agg_func ,
184+ thresholds = thresholds ,
185+ segments = segments ,
186+ zero_value_customers = zero_value_customers ,
187+ )
129188
130189
131190class SegTransactionStats :
0 commit comments