Skip to content

Commit 79a0829

Browse files
committed
feat: add input validation and tests in HMLSegmentation
1 parent d8c9965 commit 79a0829

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

pyretailscience/segmentation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def __init__(
8888
ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
8989
null values.
9090
"""
91+
if df.empty:
92+
raise ValueError("Input DataFrame is empty")
93+
9194
required_cols = ["customer_id", value_col]
9295
contract = CustomContract(
9396
df,
@@ -99,6 +102,11 @@ def __init__(
99102
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
100103
raise ValueError(msg)
101104

105+
hml_cuts = [0.500, 0.800, 1]
106+
if len(df) < len(hml_cuts):
107+
msg = f"There are {len(df)} customers, which is less than is less than the number of segment thresholds."
108+
raise ValueError(msg)
109+
102110
# Group by customer_id and calculate total_spend
103111
grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
104112

@@ -114,7 +122,7 @@ def __init__(
114122
# Create a new column 'segment' based on the total_spend
115123
hml_df["segment_name"] = pd.qcut(
116124
hml_df[value_col],
117-
q=[0, 0.500, 0.800, 1],
125+
q=[0, *hml_cuts],
118126
labels=["Light", "Medium", "Heavy"],
119127
)
120128

tests/test_segmentation.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import pytest
55

6-
from pyretailscience.segmentation import SegTransactionStats
6+
from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats
77

88

99
class TestCalcSegStats:
@@ -99,3 +99,110 @@ def test_handles_empty_dataframe_with_errors(self):
9999

100100
with pytest.raises(ValueError):
101101
SegTransactionStats(df, "segment_id")
102+
103+
104+
class TestHMLSegmentation:
105+
"""Tests for the HMLSegmentation class."""
106+
107+
@pytest.fixture()
108+
def base_df(self):
109+
"""Return a base DataFrame for testing."""
110+
return pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [1000, 200, 0, 500, 300]})
111+
112+
def test_no_transactions(self):
113+
"""Test that the method raises an error when there are no transactions."""
114+
data = {"customer_id": [], "total_price": []}
115+
df = pd.DataFrame(data)
116+
with pytest.raises(ValueError):
117+
HMLSegmentation(df)
118+
119+
# Correctly handles zero spend customers when zero_value_customers is "exclude"
120+
def test_handles_zero_spend_customers_are_excluded_in_result(self, base_df):
121+
"""Test that the method correctly handles zero spend customers when zero_value_customers is "exclude"."""
122+
hml_segmentation = HMLSegmentation(base_df, zero_value_customers="exclude")
123+
result_df = hml_segmentation.df
124+
125+
zero_spend_customer_id = 3
126+
127+
assert result_df.loc[1, "segment_name"] == "Heavy"
128+
assert result_df.loc[1, "segment_id"] == "H"
129+
assert result_df.loc[2, "segment_name"] == "Light"
130+
assert result_df.loc[2, "segment_id"] == "L"
131+
assert zero_spend_customer_id not in result_df.index
132+
assert result_df.loc[4, "segment_name"] == "Medium"
133+
assert result_df.loc[4, "segment_id"] == "M"
134+
assert result_df.loc[5, "segment_name"] == "Light"
135+
assert result_df.loc[5, "segment_id"] == "L"
136+
137+
# Correctly handles zero spend customers when zero_value_customers is "include_with_light"
138+
def test_handles_zero_spend_customers_include_with_light(self, base_df):
139+
"""Test that the method correctly handles zero spend customers when zero_value_customers is "include_with_light"."""
140+
hml_segmentation = HMLSegmentation(base_df, zero_value_customers="include_with_light")
141+
result_df = hml_segmentation.df
142+
143+
assert result_df.loc[1, "segment_name"] == "Heavy"
144+
assert result_df.loc[1, "segment_id"] == "H"
145+
assert result_df.loc[2, "segment_name"] == "Light"
146+
assert result_df.loc[2, "segment_id"] == "L"
147+
assert result_df.loc[3, "segment_name"] == "Light"
148+
assert result_df.loc[3, "segment_id"] == "L"
149+
assert result_df.loc[4, "segment_name"] == "Medium"
150+
assert result_df.loc[4, "segment_id"] == "M"
151+
assert result_df.loc[5, "segment_name"] == "Light"
152+
assert result_df.loc[5, "segment_id"] == "L"
153+
154+
# Correctly handles zero spend customers when zero_value_customers is "separate_segment"
155+
def test_handles_zero_spend_customers_separate_segment(self, base_df):
156+
"""Test that the method correctly handles zero spend customers when zero_value_customers is "separate_segment"."""
157+
hml_segmentation = HMLSegmentation(base_df, zero_value_customers="separate_segment")
158+
result_df = hml_segmentation.df
159+
160+
assert result_df.loc[1, "segment_name"] == "Heavy"
161+
assert result_df.loc[1, "segment_id"] == "H"
162+
assert result_df.loc[2, "segment_name"] == "Light"
163+
assert result_df.loc[2, "segment_id"] == "L"
164+
assert result_df.loc[3, "segment_name"] == "Zero"
165+
assert result_df.loc[3, "segment_id"] == "Z"
166+
assert result_df.loc[4, "segment_name"] == "Medium"
167+
assert result_df.loc[4, "segment_id"] == "M"
168+
assert result_df.loc[5, "segment_name"] == "Light"
169+
assert result_df.loc[5, "segment_id"] == "L"
170+
171+
# Raises ValueError if required columns are missing
172+
def test_raises_value_error_if_required_columns_missing(self, base_df):
173+
"""Test that the method raises an error when the DataFrame is missing a required column."""
174+
with pytest.raises(ValueError):
175+
HMLSegmentation(base_df.drop(columns=["customer_id"]))
176+
177+
# DataFrame with only one customer
178+
def test_segments_customer_single(self):
179+
"""Test that the method correctly segments a DataFrame with only one customer."""
180+
data = {"customer_id": [1], "total_price": [0]}
181+
df = pd.DataFrame(data)
182+
with pytest.raises(ValueError):
183+
HMLSegmentation(df)
184+
185+
# Validate that the input dataframe is not changed
186+
def test_input_dataframe_not_changed(self, base_df):
187+
"""Test that the method does not alter the original DataFrame."""
188+
original_df = base_df.copy()
189+
190+
hml_segmentation = HMLSegmentation(base_df)
191+
_ = hml_segmentation.df
192+
193+
assert original_df.equals(base_df) # Check if the original dataframe is not changed
194+
195+
def test_alternate_value_col(self, base_df):
196+
"""Test that the method correctly segments a DataFrame with an alternate value column."""
197+
base_df = base_df.rename(columns={"total_price": "quantity"})
198+
hml_segmentation = HMLSegmentation(base_df, value_col="quantity")
199+
result_df = hml_segmentation.df
200+
201+
assert result_df.loc[1, "segment_name"] == "Heavy"
202+
assert result_df.loc[1, "segment_id"] == "H"
203+
assert result_df.loc[2, "segment_name"] == "Light"
204+
assert result_df.loc[2, "segment_id"] == "L"
205+
assert result_df.loc[4, "segment_name"] == "Medium"
206+
assert result_df.loc[4, "segment_id"] == "M"
207+
assert result_df.loc[5, "segment_name"] == "Light"
208+
assert result_df.loc[5, "segment_id"] == "L"

0 commit comments

Comments
 (0)