Skip to content

Commit fa6d7d7

Browse files
committed
feat: added treshold segmentation creation
1 parent 79a0829 commit fa6d7d7

File tree

2 files changed

+259
-21
lines changed

2 files changed

+259
-21
lines changed

pyretailscience/segmentation.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,30 @@ def __init__(self, df: pd.DataFrame) -> None:
6767
self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
6868

6969

70-
class HMLSegmentation(BaseSegmentation):
71-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
70+
class ThresholdSegmentation(BaseSegmentation):
71+
"""Segments customers based on user-defined thresholds and segments."""
7272

7373
def __init__(
7474
self,
7575
df: pd.DataFrame,
76+
thresholds: list[float],
77+
segments: dict[any, str],
7678
value_col: str = "total_price",
79+
agg_func: str = "sum",
80+
zero_segment_name: str = "Zero",
81+
zero_segment_id: str = "Z",
7782
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
7883
) -> None:
79-
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
84+
"""Segments customers based on user-defined thresholds and segments.
8085
8186
Args:
8287
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
83-
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
88+
thresholds (List[float]): The percentile thresholds for segmentation.
89+
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
90+
value_col (str): The column to use for the segmentation.
91+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
92+
zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
93+
zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
8494
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
8595
customers with zero spend. Defaults to "separate_segment".
8696
@@ -102,38 +112,79 @@ def __init__(
102112
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
103113
raise ValueError(msg)
104114

105-
hml_cuts = [0.500, 0.800, 1]
106-
if len(df) < len(hml_cuts):
107-
msg = f"There are {len(df)} customers, which is less than is less than the number of segment thresholds."
115+
if len(df) < len(thresholds):
116+
msg = f"There are {len(df)} customers, which is less than the number of segment thresholds."
108117
raise ValueError(msg)
109118

119+
if set(thresholds) != set(thresholds):
120+
raise ValueError("The thresholds must be unique.")
121+
122+
thresholds = sorted(thresholds)
123+
if thresholds[0] != 0:
124+
thresholds = [0, *thresholds]
125+
if thresholds[-1] != 1:
126+
thresholds.append(1)
127+
128+
if len(thresholds) - 1 != len(segments):
129+
raise ValueError("The number of thresholds must match the number of segments.")
130+
110131
# Group by customer_id and calculate total_spend
111-
grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
132+
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)
112133

113134
# Separate customers with zero spend
114-
hml_df = grouped_df
135+
self.df = grouped_df
115136
if zero_value_customers in ["separate_segment", "exclude"]:
116137
zero_idx = grouped_df[value_col] == 0
117-
zero_cust_df = grouped_df[zero_idx]
118-
zero_cust_df["segment_name"] = "Zero"
138+
zero_cust_df = grouped_df[zero_idx].copy()
139+
zero_cust_df["segment_name"] = zero_segment_name
140+
zero_cust_df["segment_id"] = zero_segment_id
119141

120-
hml_df = grouped_df[~zero_idx]
142+
self.df = grouped_df[~zero_idx]
121143

122144
# Create a new column 'segment' based on the total_spend
123-
hml_df["segment_name"] = pd.qcut(
124-
hml_df[value_col],
125-
q=[0, *hml_cuts],
126-
labels=["Light", "Medium", "Heavy"],
145+
labels = list(segments.values())
146+
147+
self.df["segment_name"] = pd.qcut(
148+
self.df[value_col],
149+
q=thresholds,
150+
labels=labels,
127151
)
128152

153+
self.df["segment_id"] = self.df["segment_name"].map({v: k for k, v in segments.items()})
154+
129155
if zero_value_customers == "separate_segment":
130-
hml_df = pd.concat([hml_df, zero_cust_df])
156+
self.df = pd.concat([self.df, zero_cust_df])
157+
131158

132-
segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"}
159+
class HMLSegmentation(ThresholdSegmentation):
160+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
133161

134-
hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map)
162+
def __init__(
163+
self,
164+
df: pd.DataFrame,
165+
value_col: str = "total_price",
166+
agg_func: str = "sum",
167+
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
168+
) -> None:
169+
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
135170
136-
self.df = hml_df
171+
Args:
172+
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
173+
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
174+
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
175+
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
176+
customers with zero spend. Defaults to "separate_segment".
177+
"""
178+
thresholds = [0.500, 0.800, 1]
179+
segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
180+
super().__init__(
181+
df=df,
182+
value_col=value_col,
183+
agg_func=agg_func,
184+
thresholds=thresholds,
185+
segments=segments,
186+
zero_value_customers=zero_value_customers,
187+
)
137188

138189

139190
class SegTransactionStats:

tests/test_segmentation.py

Lines changed: 188 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import pytest
55

6-
from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats
6+
from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats, ThresholdSegmentation
77

88

99
class TestCalcSegStats:
@@ -90,6 +90,193 @@ def test_handles_dataframe_with_one_segment(self, base_df):
9090
pd.testing.assert_frame_equal(segment_stats, expected_output)
9191

9292

93+
class TestThresholdSegmentation:
94+
"""Tests for the ThresholdSegmentation class."""
95+
96+
def test_correct_segmentation(self):
97+
"""Test that the method correctly segments customers based on given thresholds and segments."""
98+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4], "total_price": [100, 200, 300, 400]})
99+
thresholds = [0.5, 1]
100+
segments = {0: "Low", 1: "High"}
101+
seg = ThresholdSegmentation(
102+
df=df,
103+
thresholds=thresholds,
104+
segments=segments,
105+
value_col="total_price",
106+
zero_value_customers="exclude",
107+
)
108+
result_df = seg.df
109+
assert result_df.loc[1, "segment_name"] == "Low"
110+
assert result_df.loc[2, "segment_name"] == "Low"
111+
assert result_df.loc[3, "segment_name"] == "High"
112+
assert result_df.loc[4, "segment_name"] == "High"
113+
114+
def test_single_customer(self):
115+
"""Test that the method correctly segments a DataFrame with only one customer."""
116+
df = pd.DataFrame({"customer_id": [1], "total_price": [100]})
117+
thresholds = [0.5, 1]
118+
segments = {0: "Low"}
119+
with pytest.raises(ValueError):
120+
ThresholdSegmentation(
121+
df=df,
122+
thresholds=thresholds,
123+
segments=segments,
124+
)
125+
126+
def test_correct_aggregation_function(self):
127+
"""Test that the correct aggregation function is applied for product_id custom segmentation."""
128+
df = pd.DataFrame(
129+
{
130+
"customer_id": [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5],
131+
"product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1],
132+
},
133+
)
134+
value_col = "product_id"
135+
agg_func = "nunique"
136+
137+
my_seg = ThresholdSegmentation(
138+
df=df,
139+
value_col=value_col,
140+
agg_func=agg_func,
141+
thresholds=[0.2, 0.8, 1],
142+
segments={"A": "Low", "B": "Medium", "C": "High"},
143+
zero_value_customers="separate_segment",
144+
)
145+
146+
expected_result = pd.DataFrame(
147+
{
148+
"customer_id": [1, 2, 3, 4, 5],
149+
"product_id": [1, 4, 2, 2, 3],
150+
"segment_name": ["Low", "High", "Medium", "Medium", "Medium"],
151+
"segment_id": ["A", "C", "B", "B", "B"],
152+
},
153+
)
154+
expected_result["segment_id"] = pd.Categorical(
155+
expected_result["segment_id"],
156+
categories=["A", "B", "C"],
157+
ordered=True,
158+
)
159+
expected_result["segment_name"] = pd.Categorical(
160+
expected_result["segment_name"],
161+
categories=["Low", "Medium", "High"],
162+
ordered=True,
163+
)
164+
pd.testing.assert_frame_equal(my_seg.df.reset_index(), expected_result)
165+
166+
def test_correctly_checks_segment_data(self):
167+
"""Test that the method correctly merges segment data back into the original DataFrame."""
168+
df = pd.DataFrame(
169+
{
170+
"customer_id": [1, 2, 3, 4, 5],
171+
"total_price": [100, 200, 0, 150, 0],
172+
},
173+
)
174+
value_col = "total_price"
175+
agg_func = "sum"
176+
thresholds = [0.33, 0.66, 1]
177+
segments = {"A": "Low", "B": "Medium", "C": "High"}
178+
zero_value_customers = "separate_segment"
179+
180+
# Create ThresholdSegmentation instance
181+
threshold_seg = ThresholdSegmentation(
182+
df=df,
183+
value_col=value_col,
184+
agg_func=agg_func,
185+
thresholds=thresholds,
186+
segments=segments,
187+
zero_value_customers=zero_value_customers,
188+
)
189+
190+
# Call add_segment method
191+
segmented_df = threshold_seg.add_segment(df)
192+
193+
# Assert the correct segment_name and segment_id
194+
expected_df = pd.DataFrame(
195+
{
196+
"customer_id": [1, 2, 3, 4, 5],
197+
"total_price": [100, 200, 0, 150, 0],
198+
"segment_name": ["Low", "High", "Zero", "Medium", "Zero"],
199+
"segment_id": ["A", "C", "Z", "B", "Z"],
200+
},
201+
)
202+
pd.testing.assert_frame_equal(segmented_df, expected_df)
203+
204+
def test_handles_dataframe_with_duplicate_customer_id_entries(self):
205+
"""Test that the method correctly handles a DataFrame with duplicate customer_id entries."""
206+
df = pd.DataFrame({"customer_id": [1, 2, 3, 1, 2, 3], "total_price": [100, 200, 300, 150, 250, 350]})
207+
208+
my_seg = ThresholdSegmentation(
209+
df=df,
210+
value_col="total_price",
211+
agg_func="sum",
212+
thresholds=[0.5, 0.8, 1],
213+
segments={"L": "Light", "M": "Medium", "H": "Heavy"},
214+
zero_value_customers="include_with_light",
215+
)
216+
217+
result_df = my_seg.add_segment(df)
218+
assert len(result_df) == len(df)
219+
220+
def test_correctly_maps_segment_names_to_segment_ids_with_fixed_thresholds(self):
221+
"""Test that the method correctly maps segment names to segment IDs with fixed thresholds."""
222+
# Setup
223+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
224+
value_col = "total_price"
225+
agg_func = "sum"
226+
thresholds = [0.33, 0.66, 1]
227+
segments = {1: "Low", 2: "Medium", 3: "High"}
228+
zero_value_customers = "separate_segment"
229+
230+
my_seg = ThresholdSegmentation(
231+
df=df,
232+
value_col=value_col,
233+
agg_func=agg_func,
234+
thresholds=thresholds,
235+
segments=segments,
236+
zero_value_customers=zero_value_customers,
237+
)
238+
239+
assert len(my_seg.df[["segment_id", "segment_name"]].drop_duplicates()) == len(segments)
240+
assert my_seg.df.set_index("segment_id")["segment_name"].to_dict() == segments
241+
242+
def test_thresholds_not_unique(self):
243+
"""Test that the method raises an error when the thresholds are not unique."""
244+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
245+
thresholds = [0.5, 0.5, 0.8, 1]
246+
segments = {1: "Low", 2: "Medium", 3: "High"}
247+
248+
with pytest.raises(ValueError):
249+
ThresholdSegmentation(df, thresholds, segments)
250+
251+
def test_thresholds_too_few_segments(self):
252+
"""Test that the method raises an error when there are too few/many segments for the number of thresholds."""
253+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
254+
thresholds = [0.4, 0.6, 0.8, 1]
255+
segments = {1: "Low", 3: "High"}
256+
257+
with pytest.raises(ValueError):
258+
ThresholdSegmentation(df, thresholds, segments)
259+
260+
segments = {1: "Low", 2: "Medium", 3: "High"}
261+
262+
with pytest.raises(ValueError):
263+
ThresholdSegmentation(df, thresholds, segments)
264+
265+
def test_thresholds_too_too_few_thresholds(self):
266+
"""Test that the method raises an error when there are too few/many thresholds for the number of segments."""
267+
df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]})
268+
thresholds = [0.4, 1]
269+
segments = {1: "Low", 2: "Medium", 3: "High"}
270+
271+
with pytest.raises(ValueError):
272+
ThresholdSegmentation(df, thresholds, segments)
273+
274+
thresholds = [0.2, 0.5, 0.6, 0.8, 1]
275+
276+
with pytest.raises(ValueError):
277+
ThresholdSegmentation(df, thresholds, segments)
278+
279+
93280
class TestSegTransactionStats:
94281
"""Tests for the SegTransactionStats class."""
95282

0 commit comments

Comments
 (0)