Skip to content

Commit d88fc65

Browse files
committed
feat(tests): integrate dynamic column names using get_option in segmentation code and related tests
1 parent 347823a commit d88fc65

File tree

2 files changed

+133
-71
lines changed

2 files changed

+133
-71
lines changed

pyretailscience/segmentation.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
build_expected_unique_columns,
1313
build_non_null_columns,
1414
)
15+
from pyretailscience.options import get_option
1516
from pyretailscience.style.tailwind import COLORS
1617

1718

@@ -31,7 +32,12 @@ def add_segment(self, df: pd.DataFrame) -> pd.DataFrame:
3132
ValueError: If the number of rows before and after the merge do not match.
3233
"""
3334
rows_before = len(df)
34-
df = df.merge(self.df[["segment_name", "segment_id"]], how="left", left_on="customer_id", right_index=True)
35+
df = df.merge(
36+
self.df[["segment_name", "segment_id"]],
37+
how="left",
38+
left_on=get_option("column.customer_id"),
39+
right_index=True,
40+
)
3541
rows_after = len(df)
3642
if rows_before != rows_after:
3743
raise ValueError("The number of rows before and after the merge do not match. This should not happen.")
@@ -51,7 +57,7 @@ def __init__(self, df: pd.DataFrame) -> None:
5157
Raises:
5258
ValueError: If the dataframe does not have the columns customer_id, segment_name and segment_id.
5359
"""
54-
required_cols = "customer_id", "segment_name", "segment_id"
60+
required_cols = get_option("column.customer_id"), "segment_name", "segment_id"
5561
contract = CustomContract(
5662
df,
5763
basic_expectations=build_expected_columns(columns=required_cols),
@@ -63,7 +69,9 @@ def __init__(self, df: pd.DataFrame) -> None:
6369
msg = f"The dataframe requires the columns {required_cols} and they must be non-null and unique."
6470
raise ValueError(msg)
6571

66-
self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
72+
self.df = df[[get_option("column.customer_id"), "segment_name", "segment_id"]].set_index(
73+
get_option("column.customer_id"),
74+
)
6775

6876

6977
class ThresholdSegmentation(BaseSegmentation):
@@ -74,7 +82,7 @@ def __init__(
7482
df: pd.DataFrame,
7583
thresholds: list[float],
7684
segments: dict[any, str],
77-
value_col: str = "total_price",
85+
value_col: str | None = None,
7886
agg_func: str = "sum",
7987
zero_segment_name: str = "Zero",
8088
zero_segment_id: str = "Z",
@@ -86,7 +94,7 @@ def __init__(
8694
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
8795
thresholds (List[float]): The percentile thresholds for segmentation.
8896
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
89-
value_col (str): The column to use for the segmentation.
97+
value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend").
9098
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
9199
zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
92100
zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
@@ -100,7 +108,9 @@ def __init__(
100108
if df.empty:
101109
raise ValueError("Input DataFrame is empty")
102110

103-
required_cols = ["customer_id", value_col]
111+
value_col = get_option("column.unit_spend") if value_col is None else value_col
112+
113+
required_cols = [get_option("column.customer_id"), value_col]
104114
contract = CustomContract(
105115
df,
106116
basic_expectations=build_expected_columns(columns=required_cols),
@@ -128,7 +138,7 @@ def __init__(
128138
raise ValueError("The number of thresholds must match the number of segments.")
129139

130140
# Group by customer_id and calculate total_spend
131-
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)
141+
grouped_df = df.groupby(get_option("column.customer_id"))[value_col].agg(agg_func).to_frame(value_col)
132142

133143
# Separate customers with zero spend
134144
self.df = grouped_df
@@ -138,7 +148,7 @@ def __init__(
138148
zero_cust_df["segment_name"] = zero_segment_name
139149
zero_cust_df["segment_id"] = zero_segment_id
140150

141-
self.df = grouped_df[~zero_idx]
151+
self.df = grouped_df[~zero_idx].copy()
142152

143153
# Create a new column 'segment' based on the total_spend
144154
labels = list(segments.values())
@@ -161,20 +171,25 @@ class HMLSegmentation(ThresholdSegmentation):
161171
def __init__(
162172
self,
163173
df: pd.DataFrame,
164-
value_col: str = "total_price",
174+
value_col: str | None = None,
165175
agg_func: str = "sum",
166176
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
167177
) -> None:
168178
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
169179
180+
HMLSegmentation is a subclass of ThresholdSegmentation and based around an industry standard definition. The
181+
thresholds for Heavy (top 20%), Medium (next 30%) and Light (bottom 50%) are chosen based on the pareto
182+
distribution, commonly know as the 80/20 rule. It is typically used in retail to segment customers based on
183+
their spend, transaction volume or quantities purchased.
184+
170185
Args:
171186
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
172-
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
187+
value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend").
173188
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
174189
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
175190
customers with zero spend. Defaults to "separate_segment".
176191
"""
177-
thresholds = [0.500, 0.800, 1]
192+
thresholds = [0, 0.500, 0.800, 1]
178193
segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
179194
super().__init__(
180195
df=df,
@@ -202,9 +217,14 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
202217
TransactionLevelContract.
203218
204219
"""
205-
required_cols = ["customer_id", "total_price", "transaction_id", segment_col]
206-
if "quantity" in df.columns:
207-
required_cols.append("quantity")
220+
required_cols = [
221+
get_option("column.customer_id"),
222+
get_option("column.unit_spend"),
223+
get_option("column.transaction_id"),
224+
segment_col,
225+
]
226+
if get_option("column.unit_quantity") in df.columns:
227+
required_cols.append(get_option("column.unit_quantity"))
208228
contract = CustomContract(
209229
df,
210230
basic_expectations=build_expected_columns(columns=required_cols),
@@ -222,18 +242,18 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
222242
@staticmethod
223243
def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
224244
aggs = {
225-
"revenue": ("total_price", "sum"),
226-
"transactions": ("transaction_id", "nunique"),
227-
"customers": ("customer_id", "nunique"),
245+
get_option("column.agg.unit_spend"): (get_option("column.unit_spend"), "sum"),
246+
get_option("column.agg.transaction_id"): (get_option("column.transaction_id"), "nunique"),
247+
get_option("column.agg.customer_id"): (get_option("column.customer_id"), "nunique"),
228248
}
229249
total_aggs = {
230-
"revenue": [df["total_price"].sum()],
231-
"transactions": [df["transaction_id"].nunique()],
232-
"customers": [df["customer_id"].nunique()],
250+
get_option("column.agg.unit_spend"): [df[get_option("column.unit_spend")].sum()],
251+
get_option("column.agg.transaction_id"): [df[get_option("column.transaction_id")].nunique()],
252+
get_option("column.agg.customer_id"): [df[get_option("column.customer_id")].nunique()],
233253
}
234-
if "quantity" in df.columns:
235-
aggs["total_quantity"] = ("quantity", "sum")
236-
total_aggs["total_quantity"] = [df["quantity"].sum()]
254+
if get_option("column.unit_quantity") in df.columns:
255+
aggs[get_option("column.agg.unit_quantity")] = (get_option("column.unit_quantity"), "sum")
256+
total_aggs[get_option("column.agg.unit_quantity")] = [df[get_option("column.unit_quantity")].sum()]
237257

238258
stats_df = pd.concat(
239259
[
@@ -242,9 +262,13 @@ def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
242262
],
243263
)
244264

245-
if "quantity" in df.columns:
246-
stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
247-
stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
265+
if get_option("column.unit_quantity") in df.columns:
266+
stats_df[get_option("column.calc.price_per_unit")] = (
267+
stats_df[get_option("column.agg.unit_spend")] / stats_df[get_option("column.agg.unit_quantity")]
268+
)
269+
stats_df[get_option("column.calc.units_per_transaction")] = (
270+
stats_df[get_option("column.agg.unit_quantity")] / stats_df[get_option("column.agg.transaction_id")]
271+
)
248272

249273
return stats_df
250274

0 commit comments

Comments
 (0)