Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 76 additions & 17 deletions pyretailscience/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,40 @@ def __init__(self, df: pd.DataFrame) -> None:
self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")


class HMLSegmentation(BaseSegmentation):
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""
class ThresholdSegmentation(BaseSegmentation):
"""Segments customers based on user-defined thresholds and segments."""

def __init__(
self,
df: pd.DataFrame,
thresholds: list[float],
segments: dict[any, str],
value_col: str = "total_price",
agg_func: str = "sum",
zero_segment_name: str = "Zero",
zero_segment_id: str = "Z",
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
) -> None:
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
"""Segments customers based on user-defined thresholds and segments.

Args:
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
thresholds (List[float]): The percentile thresholds for segmentation.
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names.
value_col (str): The column to use for the segmentation.
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero".
zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z".
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
customers with zero spend. Defaults to "separate_segment".

Raises:
ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
null values.
"""
if df.empty:
raise ValueError("Input DataFrame is empty")

required_cols = ["customer_id", value_col]
contract = CustomContract(
df,
Expand All @@ -99,33 +112,79 @@ def __init__(
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
raise ValueError(msg)

if len(df) < len(thresholds):
msg = f"There are {len(df)} customers, which is less than the number of segment thresholds."
raise ValueError(msg)

if set(thresholds) != set(thresholds):
raise ValueError("The thresholds must be unique.")

thresholds = sorted(thresholds)
if thresholds[0] != 0:
thresholds = [0, *thresholds]
if thresholds[-1] != 1:
thresholds.append(1)

if len(thresholds) - 1 != len(segments):
raise ValueError("The number of thresholds must match the number of segments.")

# Group by customer_id and calculate total_spend
grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col)

# Separate customers with zero spend
hml_df = grouped_df
self.df = grouped_df
if zero_value_customers in ["separate_segment", "exclude"]:
zero_idx = grouped_df[value_col] == 0
zero_cust_df = grouped_df[zero_idx]
zero_cust_df["segment_name"] = "Zero"
zero_cust_df = grouped_df[zero_idx].copy()
zero_cust_df["segment_name"] = zero_segment_name
zero_cust_df["segment_id"] = zero_segment_id

hml_df = grouped_df[~zero_idx]
self.df = grouped_df[~zero_idx]

# Create a new column 'segment' based on the total_spend
hml_df["segment_name"] = pd.qcut(
hml_df[value_col],
q=[0, 0.500, 0.800, 1],
labels=["Light", "Medium", "Heavy"],
labels = list(segments.values())

self.df["segment_name"] = pd.qcut(
self.df[value_col],
q=thresholds,
labels=labels,
)

self.df["segment_id"] = self.df["segment_name"].map({v: k for k, v in segments.items()})

if zero_value_customers == "separate_segment":
hml_df = pd.concat([hml_df, zero_cust_df])
self.df = pd.concat([self.df, zero_cust_df])

segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"}

hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map)
class HMLSegmentation(ThresholdSegmentation):
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend."""

self.df = hml_df
def __init__(
self,
df: pd.DataFrame,
value_col: str = "total_price",
agg_func: str = "sum",
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
) -> None:
"""Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.

Args:
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum".
zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle
customers with zero spend. Defaults to "separate_segment".
"""
thresholds = [0.500, 0.800, 1]
segments = {"L": "Light", "M": "Medium", "H": "Heavy"}
super().__init__(
df=df,
value_col=value_col,
agg_func=agg_func,
thresholds=thresholds,
segments=segments,
zero_value_customers=zero_value_customers,
)


class SegTransactionStats:
Expand Down
Loading