Skip to content

Commit fbf887d

Browse files
authored
Added ThresholdSegmentation class (#58)
* feat: add input validation and tests in HMLSegmentation * feat: added treshold segmentation creation
1 parent d8c9965 commit fbf887d

File tree

2 files changed

+371
-18
lines changed

2 files changed

+371
-18
lines changed

pyretailscience/segmentation.py

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,40 @@ def __init__(self, df: pd.DataFrame) -> None:
6767
self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
6868

6969

70-
class HMLSegmentation(BaseSegmentation):
71-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
70+
class ThresholdSegmentation(BaseSegmentation):
71+
"""Segments customers based on user-defined thresholds and segments."""
7272

7373
def __init__(
7474
self,
7575
df: pd.DataFrame,
76+
thresholds: list[float],
77+
segments: dict[any, str],
7678
value_col: str = "total_price",
79+
agg_func: str = "sum",
80+
zero_segment_name: str = "Zero",
81+
zero_segment_id: str = "Z",
7782
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
7883
) -> None:
79-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
84+
"""Segments customers based on user-defined thresholds and segments.
8085
8186
Args:
8287
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
83-
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
88+
thresholds (List[float]): The percentile thresholds for segmentation.
89+
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
90+
value_col (str): The column to use for the segmentation.
91+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
92+
zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
93+
zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
8494
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
8595
customers with zero spend. Defaults to "separate_segment".
8696
8797
Raises:
8898
ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
8999
null values.
90100
"""
101+
if df.empty:
102+
raise ValueError("Input DataFrame is empty")
103+
91104
required_cols = ["customer_id", value_col]
92105
contract = CustomContract(
93106
df,
@@ -99,33 +112,79 @@ def __init__(
99112
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
100113
raise ValueError(msg)
101114

115+
if len(df) < len(thresholds):
116+
msg = f"There are {len(df)} customers, which is less than the number of segment thresholds."
117+
raise ValueError(msg)
118+
119+
if set(thresholds) != set(thresholds):
120+
raise ValueError("The thresholds must be unique.")
121+
122+
thresholds = sorted(thresholds)
123+
if thresholds[0] != 0:
124+
thresholds = [0, *thresholds]
125+
if thresholds[-1] != 1:
126+
thresholds.append(1)
127+
128+
if len(thresholds) - 1 != len(segments):
129+
raise ValueError("The number of thresholds must match the number of segments.")
130+
102131
# Group by customer_id and calculate total_spend
103-
grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
132+
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)
104133

105134
# Separate customers with zero spend
106-
hml_df = grouped_df
135+
self.df = grouped_df
107136
if zero_value_customers in ["separate_segment", "exclude"]:
108137
zero_idx = grouped_df[value_col] == 0
109-
zero_cust_df = grouped_df[zero_idx]
110-
zero_cust_df["segment_name"] = "Zero"
138+
zero_cust_df = grouped_df[zero_idx].copy()
139+
zero_cust_df["segment_name"] = zero_segment_name
140+
zero_cust_df["segment_id"] = zero_segment_id
111141

112-
hml_df = grouped_df[~zero_idx]
142+
self.df = grouped_df[~zero_idx]
113143

114144
# Create a new column 'segment' based on the total_spend
115-
hml_df["segment_name"] = pd.qcut(
116-
hml_df[value_col],
117-
q=[0, 0.500, 0.800, 1],
118-
labels=["Light", "Medium", "Heavy"],
145+
labels = list(segments.values())
146+
147+
self.df["segment_name"] = pd.qcut(
148+
self.df[value_col],
149+
q=thresholds,
150+
labels=labels,
119151
)
120152

153+
self.df["segment_id"] = self.df["segment_name"].map({v: k for k, v in segments.items()})
154+
121155
if zero_value_customers == "separate_segment":
122-
hml_df = pd.concat([hml_df, zero_cust_df])
156+
self.df = pd.concat([self.df, zero_cust_df])
123157

124-
segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"}
125158

126-
hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map)
159+
class HMLSegmentation(ThresholdSegmentation):
160+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
127161

128-
self.df = hml_df
162+
def __init__(
163+
self,
164+
df: pd.DataFrame,
165+
value_col: str = "total_price",
166+
agg_func: str = "sum",
167+
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
168+
) -> None:
169+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
170+
171+
Args:
172+
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
173+
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
174+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
175+
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
176+
customers with zero spend. Defaults to "separate_segment".
177+
"""
178+
thresholds = [0.500, 0.800, 1]
179+
segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
180+
super().__init__(
181+
df=df,
182+
value_col=value_col,
183+
agg_func=agg_func,
184+
thresholds=thresholds,
185+
segments=segments,
186+
zero_value_customers=zero_value_customers,
187+
)
129188

130189

131190
class SegTransactionStats:

0 commit comments

Comments
 (0)