Skip to content

Commit 448a0dd

Browse files
committed
feat: added treshold segmentation creation
1 parent 79a0829 commit 448a0dd

2 files changed

Lines changed: 235 additions & 21 deletions

File tree

pyretailscience/segmentation.py

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,30 @@ def __init__(self, df: pd.DataFrame) -> None:
6767
self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
6868

6969

70-
class HMLSegmentation(BaseSegmentation):
71-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
70+
class ThresholdSegmentation(BaseSegmentation):
71+
"""Segments customers based on user-defined thresholds and segments."""
7272

7373
def __init__(
7474
self,
7575
df: pd.DataFrame,
76+
thresholds: list[float],
77+
segments: dict[any, str],
7678
value_col: str = "total_price",
79+
agg_func: str = "sum",
80+
zero_segment_name: str = "Zero",
81+
zero_segment_id: str = "Z",
7782
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
7883
) -> None:
79-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
84+
"""Segments customers based on user-defined thresholds and segments.
8085
8186
Args:
8287
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
83-
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
88+
thresholds (List[float]): The percentile thresholds for segmentation.
89+
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
90+
value_col (str): The column to use for the segmentation.
91+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
92+
zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
93+
zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
8494
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
8595
customers with zero spend. Defaults to "separate_segment".
8696
@@ -102,38 +112,78 @@ def __init__(
102112
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
103113
raise ValueError(msg)
104114

105-
hml_cuts = [0.500, 0.800, 1]
106-
if len(df) < len(hml_cuts):
107-
msg = f"There are {len(df)} customers, which is less than is less than the number of segment thresholds."
115+
if len(df) < len(thresholds):
116+
msg = f"There are {len(df)} customers, which is less than the number of segment thresholds."
108117
raise ValueError(msg)
109118

119+
if len(thresholds) != len(segments):
120+
raise ValueError("The number of thresholds must match the number of segments.")
121+
110122
# Group by customer_id and calculate total_spend
111-
grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
123+
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)
112124

113125
# Separate customers with zero spend
114-
hml_df = grouped_df
126+
self.df = grouped_df
115127
if zero_value_customers in ["separate_segment", "exclude"]:
116128
zero_idx = grouped_df[value_col] == 0
117-
zero_cust_df = grouped_df[zero_idx]
118-
zero_cust_df["segment_name"] = "Zero"
129+
zero_cust_df = grouped_df[zero_idx].copy()
130+
zero_cust_df["segment_name"] = zero_segment_name
131+
zero_cust_df["segment_id"] = zero_segment_id
119132

120-
hml_df = grouped_df[~zero_idx]
133+
self.df = grouped_df[~zero_idx]
121134

122135
# Create a new column 'segment' based on the total_spend
123-
hml_df["segment_name"] = pd.qcut(
124-
hml_df[value_col],
125-
q=[0, *hml_cuts],
126-
labels=["Light", "Medium", "Heavy"],
136+
labels = list(segments.values())
137+
q = thresholds
138+
if thresholds[0] != 0:
139+
q = [0, *thresholds]
140+
141+
self.df["segment_name"] = pd.qcut(
142+
self.df[value_col],
143+
q=q,
144+
labels=labels,
127145
)
128146

147+
if self.df["segment_name"].isnull().any():
148+
raise ValueError(
149+
"Based on the tresholds selected some customers weren't segmented.",
150+
"Please check the thresholds cover all the values from 0 to 1",
151+
)
152+
self.df["segment_id"] = self.df["segment_name"].map({v: k for k, v in segments.items()})
153+
129154
if zero_value_customers == "separate_segment":
130-
hml_df = pd.concat([hml_df, zero_cust_df])
155+
self.df = pd.concat([self.df, zero_cust_df])
131156

132-
segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"}
133157

134-
hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map)
158+
class HMLSegmentation(ThresholdSegmentation):
159+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
160+
161+
def __init__(
162+
self,
163+
df: pd.DataFrame,
164+
value_col: str = "total_price",
165+
agg_func: str = "sum",
166+
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
167+
) -> None:
168+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
135169
136-
self.df = hml_df
170+
Args:
171+
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
172+
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
173+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
174+
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
175+
customers with zero spend. Defaults to "separate_segment".
176+
"""
177+
thresholds = [0.500, 0.800, 1]
178+
segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
179+
super().__init__(
180+
df=df,
181+
value_col=value_col,
182+
agg_func=agg_func,
183+
thresholds=thresholds,
184+
segments=segments,
185+
zero_value_customers=zero_value_customers,
186+
)
137187

138188

139189
class SegTransactionStats:

tests/test_segmentation.py

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import pytest
55

6-
from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats
6+
from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats, ThresholdSegmentation
77

88

99
class TestCalcSegStats:
@@ -90,6 +90,170 @@ def test_handles_dataframe_with_one_segment(self, base_df):
9090
pd.testing.assert_frame_equal(segment_stats, expected_output)
9191

9292

93+
class TestThresholdSegmentation:
94+
"""Tests for the ThresholdSegmentation class."""
95+
96+
def test_correct_segmentation(self):
97+
"""Test that the method correctly segments customers based on given thresholds and segments."""
98+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4], "total_price": [100, 200, 300, 400]})
99+
thresholds = [0.5, 1]
100+
segments = {0: "Low", 1: "High"}
101+
seg = ThresholdSegmentation(
102+
df=df,
103+
thresholds=thresholds,
104+
segments=segments,
105+
value_col="total_price",
106+
zero_value_customers="exclude",
107+
)
108+
result_df = seg.df
109+
assert result_df.loc[1, "segment_name"] == "Low"
110+
assert result_df.loc[2, "segment_name"] == "Low"
111+
assert result_df.loc[3, "segment_name"] == "High"
112+
assert result_df.loc[4, "segment_name"] == "High"
113+
114+
def test_single_customer(self):
115+
"""Test that the method correctly segments a DataFrame with only one customer."""
116+
df = pd.DataFrame({"customer_id": [1], "total_price": [100]})
117+
thresholds = [0.5, 1]
118+
segments = {0: "Low"}
119+
with pytest.raises(ValueError):
120+
ThresholdSegmentation(
121+
df=df,
122+
thresholds=thresholds,
123+
segments=segments,
124+
)
125+
126+
def test_correct_aggregation_function(self):
127+
"""Test that the correct aggregation function is applied for product_id custom segmentation."""
128+
df = pd.DataFrame(
129+
{
130+
"customer_id": [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5],
131+
"product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1],
132+
},
133+
)
134+
value_col = "product_id"
135+
agg_func = "nunique"
136+
137+
my_seg = ThresholdSegmentation(
138+
df=df,
139+
value_col=value_col,
140+
agg_func=agg_func,
141+
thresholds=[0.2, 0.8, 1],
142+
segments={"A": "Low", "B": "Medium", "C": "High"},
143+
zero_value_customers="separate_segment",
144+
)
145+
# 1 - 1 products
146+
# 2 - 4 products
147+
# 3 - 2 products
148+
# 4 - 2 products
149+
# 5 - 3 product
150+
151+
expected_result = pd.DataFrame(
152+
{
153+
"customer_id": [1, 2, 3, 4, 5],
154+
"product_id": [1, 4, 2, 2, 3],
155+
"segment_name": ["Low", "High", "Medium", "Medium", "Medium"],
156+
"segment_id": ["A", "C", "B", "B", "B"],
157+
},
158+
)
159+
expected_result["segment_id"] = pd.Categorical(
160+
expected_result["segment_id"],
161+
categories=["A", "B", "C"],
162+
ordered=True,
163+
)
164+
expected_result["segment_name"] = pd.Categorical(
165+
expected_result["segment_name"],
166+
categories=["Low", "Medium", "High"],
167+
ordered=True,
168+
)
169+
pd.testing.assert_frame_equal(my_seg.df.reset_index(), expected_result)
170+
171+
def test_correctly_checks_segment_data(self):
172+
"""Test that the method correctly merges segment data back into the original DataFrame."""
173+
df = pd.DataFrame(
174+
{
175+
"customer_id": [1, 2, 3, 4, 5],
176+
"total_price": [100, 200, 0, 150, 0],
177+
},
178+
)
179+
value_col = "total_price"
180+
agg_func = "sum"
181+
thresholds = [0.33, 0.66, 1]
182+
segments = {"A": "Low", "B": "Medium", "C": "High"}
183+
zero_value_customers = "separate_segment"
184+
185+
# Create ThresholdSegmentation instance
186+
threshold_seg = ThresholdSegmentation(
187+
df=df,
188+
value_col=value_col,
189+
agg_func=agg_func,
190+
thresholds=thresholds,
191+
segments=segments,
192+
zero_value_customers=zero_value_customers,
193+
)
194+
195+
# Call add_segment method
196+
segmented_df = threshold_seg.add_segment(df)
197+
198+
# Assert the correct segment_name and segment_id
199+
expected_df = pd.DataFrame(
200+
{
201+
"customer_id": [1, 2, 3, 4, 5],
202+
"total_price": [100, 200, 0, 150, 0],
203+
"segment_name": ["Low", "High", "Zero", "Medium", "Zero"],
204+
"segment_id": ["A", "C", "Z", "B", "Z"],
205+
},
206+
)
207+
pd.testing.assert_frame_equal(segmented_df, expected_df)
208+
209+
def test_handles_dataframe_with_duplicate_customer_id_entries(self):
210+
"""Test that the method correctly handles a DataFrame with duplicate customer_id entries."""
211+
df = pd.DataFrame({"customer_id": [1, 2, 3, 1, 2, 3], "total_price": [100, 200, 300, 150, 250, 350]})
212+
213+
my_seg = ThresholdSegmentation(
214+
df=df,
215+
value_col="total_price",
216+
agg_func="sum",
217+
thresholds=[0.5, 0.8, 1],
218+
segments={"L": "Light", "M": "Medium", "H": "Heavy"},
219+
zero_value_customers="include_with_light",
220+
)
221+
222+
result_df = my_seg.add_segment(df)
223+
assert len(result_df) == len(df)
224+
225+
def test_correctly_maps_segment_names_to_segment_ids_with_fixed_thresholds(self):
226+
"""Test that the method correctly maps segment names to segment IDs with fixed thresholds."""
227+
# Setup
228+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
229+
value_col = "total_price"
230+
agg_func = "sum"
231+
thresholds = [0.33, 0.66, 1]
232+
segments = {1: "Low", 2: "Medium", 3: "High"}
233+
zero_value_customers = "separate_segment"
234+
235+
my_seg = ThresholdSegmentation(
236+
df=df,
237+
value_col=value_col,
238+
agg_func=agg_func,
239+
thresholds=thresholds,
240+
segments=segments,
241+
zero_value_customers=zero_value_customers,
242+
)
243+
244+
assert len(my_seg.df[["segment_id", "segment_name"]].drop_duplicates()) == len(segments)
245+
assert my_seg.df.set_index("segment_id")["segment_name"].to_dict() == segments
246+
247+
def test_thresholds_do_not_cover_all_values(self):
248+
"""Test that the method raises an error when the thresholds do not cover all values."""
249+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
250+
thresholds = [0.5, 0.8] # Missing the upper bound
251+
segments = {1: "Low", 2: "Medium", 3: "High"}
252+
253+
with pytest.raises(ValueError):
254+
ThresholdSegmentation(df, thresholds, segments)
255+
256+
93257
class TestSegTransactionStats:
94258
"""Tests for the SegTransactionStats class."""
95259

0 commit comments

Comments
 (0)