diff --git a/docs/analysis_modules.md b/docs/analysis_modules.md index 7a86777..7c3f30c 100644 --- a/docs/analysis_modules.md +++ b/docs/analysis_modules.md @@ -681,7 +681,7 @@ Example: ```python from pyretailscience.plots import bar -from pyretailscience.analysis.segmentation import HMLSegmentation +from pyretailscience.segmentation.hml import HMLSegmentation seg = HMLSegmentation(df, zero_value_customers="include_with_light") @@ -724,7 +724,7 @@ Example: ```python from pyretailscience.plots import bar -from pyretailscience.analysis.segmentation import ThresholdSegmentation +from pyretailscience.segmentation.threshold import ThresholdSegmentation # Create custom segmentation with quartiles # Define thresholds at 25%, 50%, 75%, and 100% (quartiles) @@ -766,7 +766,8 @@ segmentation. Example: ```python -from pyretailscience.analysis.segmentation import HMLSegmentation, SegTransactionStats +from pyretailscience.segmentation.segstats import SegTransactionStats +from pyretailscience.segmentation.hml import HMLSegmentation seg = HMLSegmentation(df, zero_value_customers="include_with_light") @@ -818,7 +819,7 @@ Example: ```python import pandas as pd -from pyretailscience.analysis.segmentation import RFMSegmentation +from pyretailscience.segmentation.rfm import RFMSegmentation data = pd.DataFrame({ "customer_id": [1, 1, 2, 2, 3, 3, 3], diff --git a/docs/api/segmentation/base.md b/docs/api/segmentation/base.md new file mode 100644 index 0000000..d3b574b --- /dev/null +++ b/docs/api/segmentation/base.md @@ -0,0 +1,3 @@ +# Base Segmentation + +::: pyretailscience.segmentation.base diff --git a/docs/api/segmentation/hml.md b/docs/api/segmentation/hml.md new file mode 100644 index 0000000..afabd03 --- /dev/null +++ b/docs/api/segmentation/hml.md @@ -0,0 +1,3 @@ +# HML Segmentation + +::: pyretailscience.segmentation.hml diff --git a/docs/api/segmentation/rfm.md b/docs/api/segmentation/rfm.md new file mode 100644 index 0000000..2382f08 --- /dev/null +++ 
b/docs/api/segmentation/rfm.md @@ -0,0 +1,3 @@ +# RFM Segmentation + +::: pyretailscience.segmentation.rfm diff --git a/docs/api/segmentation/segstats.md b/docs/api/segmentation/segstats.md new file mode 100644 index 0000000..4692c0a --- /dev/null +++ b/docs/api/segmentation/segstats.md @@ -0,0 +1,3 @@ +# SegTransactionStats Segmentation + +::: pyretailscience.segmentation.segstats diff --git a/docs/api/segmentation/threshold.md b/docs/api/segmentation/threshold.md new file mode 100644 index 0000000..e919d48 --- /dev/null +++ b/docs/api/segmentation/threshold.md @@ -0,0 +1,3 @@ +# Threshold Segmentation + +::: pyretailscience.segmentation.threshold diff --git a/mkdocs.yml b/mkdocs.yml index 961f1e8..267d096 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,7 +26,12 @@ nav: - Haversine Distance: api/analysis/haversine.md - Product Association: api/analysis/product_association.md - Revenue Tree: api/analysis/revenue_tree.md - - Segmentation: api/analysis/segmentation.md + - Segmentation: + - Base Segmentation: api/segmentation/base.md + - HML Segmentation: api/segmentation/hml.md + - RFM Segmentation: api/segmentation/rfm.md + - SegTransactionStats Segmentation: api/segmentation/segstats.md + - Threshold Segmentation: api/segmentation/threshold.md - Plots: - Area Plot: api/plots/area.md - Bar Plot: api/plots/bar.md diff --git a/pyretailscience/segmentation/__init__.py b/pyretailscience/segmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyretailscience/segmentation/base.py b/pyretailscience/segmentation/base.py new file mode 100644 index 0000000..73a8381 --- /dev/null +++ b/pyretailscience/segmentation/base.py @@ -0,0 +1,34 @@ +"""This module provides a base class for segmenting customers based on their spend and transaction statistics.""" + +import pandas as pd + +from pyretailscience.options import get_option + + +class BaseSegmentation: + """A base class for customer segmentation.""" + + def add_segment(self, df: pd.DataFrame) -> 
pd.DataFrame: + """Adds the segment to the dataframe based on the customer_id column. + + Args: + df (pd.DataFrame): The dataframe to add the segment to. The dataframe must have a customer_id column. + + Returns: + pd.DataFrame: The dataframe with the segment added. + + Raises: + ValueError: If the number of rows before and after the merge do not match. + """ + rows_before = len(df) + df = df.merge( + self.df["segment_name"], + how="left", + left_on=get_option("column.customer_id"), + right_index=True, + ) + rows_after = len(df) + if rows_before != rows_after: + raise ValueError("The number of rows before and after the merge do not match. This should not happen.") + + return df diff --git a/pyretailscience/segmentation/hml.py b/pyretailscience/segmentation/hml.py new file mode 100644 index 0000000..0752bb3 --- /dev/null +++ b/pyretailscience/segmentation/hml.py @@ -0,0 +1,49 @@ +"""This module provides the `HMLSegmentation` class for categorizing customers into spend-based segments. + +HMLSegmentation extends `ThresholdSegmentation` and classifies customers into Heavy, Medium, Light, +and optionally Zero spenders based on the Pareto principle (80/20 rule). It is commonly used in retail +to analyze customer spending behavior and optimize marketing strategies. +""" + +from typing import Literal + +import ibis +import pandas as pd + +from pyretailscience.segmentation.threshold import ThresholdSegmentation + + +class HMLSegmentation(ThresholdSegmentation): + """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.""" + + def __init__( + self, + df: pd.DataFrame | ibis.Table, + value_col: str | None = None, + agg_func: str = "sum", + zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", + ) -> None: + """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend. 
+ + HMLSegmentation is a subclass of ThresholdSegmentation and based around an industry standard definition. The + thresholds for Heavy (top 20%), Medium (next 30%) and Light (bottom 50%) are chosen based on the Pareto + distribution, commonly known as the 80/20 rule. It is typically used in retail to segment customers based on + their spend, transaction volume or quantities purchased. + + Args: + df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. + value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). + agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". + zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle + customers with zero spend. Defaults to "separate_segment". + """ + thresholds = [0.500, 0.800, 1] + segments = ["Light", "Medium", "Heavy"] + super().__init__( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=thresholds, + segments=segments, + zero_value_customers=zero_value_customers, + ) diff --git a/pyretailscience/segmentation/rfm.py b/pyretailscience/segmentation/rfm.py new file mode 100644 index 0000000..219a051 --- /dev/null +++ b/pyretailscience/segmentation/rfm.py @@ -0,0 +1,138 @@ +"""Customer Segmentation Using RFM Analysis. + +This module implements RFM (Recency, Frequency, Monetary) segmentation, a widely used technique in customer analytics +to categorize customers based on their purchasing behavior. + +RFM segmentation assigns scores to customers based on: +1. Recency (R): How recently a customer made a purchase. +2. Frequency (F): How often a customer makes purchases. +3. Monetary (M): The total amount spent by a customer. + +### Benefits of RFM Segmentation: +- **Customer Value Analysis**: Identifies high-value customers who contribute the most revenue. 
+- **Personalized Marketing**: Enables targeted campaigns based on customer purchasing behavior. +- **Customer Retention Strategies**: Helps recognize at-risk customers and develop engagement strategies. +- **Sales Forecasting**: Provides insights into future revenue trends based on past spending behavior. + +### Scoring Methodology: +- Each metric (R, F, M) is divided into 10 bins (0-9) using the NTILE(10) function. +- A higher score indicates a better customer (e.g., lower recency, higher frequency, and monetary value). +- The final RFM segment is computed as `R*100 + F*10 + M`, providing a unique customer classification. + +This module leverages `pandas` and `ibis` for efficient data processing and integrates with retail analytics workflows +to enhance customer insights and business decision-making. +""" + +import datetime + +import ibis +import pandas as pd + +from pyretailscience.options import ColumnHelper, get_option + + +class RFMSegmentation: + """Segments customers using the RFM (Recency, Frequency, Monetary) methodology. + + Customers are scored on three dimensions: + - Recency (R): Days since the last transaction (lower is better). + - Frequency (F): Number of unique transactions (higher is better). + - Monetary (M): Total amount spent (higher is better). + + Each metric is ranked into 10 bins (0-9) using NTILE(10) where, + - 9 represents the best score (top 10% of customers). + - 0 represents the lowest score (bottom 10% of customers). + The RFM segment is a 3-digit number (R*100 + F*10 + M), representing customer value. + """ + + _df: pd.DataFrame | None = None + + def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | datetime.date | None = None) -> None: + """Initializes the RFM segmentation process. + + Args: + df (pd.DataFrame | ibis.Table): A DataFrame or Ibis table containing transaction data. 
+ Must include the following columns: + - customer_id + - transaction_date + - unit_spend + - transaction_id + current_date (Optional[Union[str, datetime.date]]): The reference date for calculating recency. + Can be a string (format: "YYYY-MM-DD"), a date object, or None (defaults to the current system date). + + Raises: + ValueError: If the dataframe is missing required columns. + TypeError: If the input data is not a pandas DataFrame or an Ibis Table. + """ + cols = ColumnHelper() + required_cols = [ + cols.customer_id, + cols.transaction_date, + cols.unit_spend, + cols.transaction_id, + ] + if isinstance(df, pd.DataFrame): + df = ibis.memtable(df) + elif not isinstance(df, ibis.Table): + raise TypeError("df must be either a pandas DataFrame or an Ibis Table") + + missing_cols = set(required_cols) - set(df.columns) + if missing_cols: + error_message = f"Missing required columns: {missing_cols}" + raise ValueError(error_message) + + if isinstance(current_date, str): + current_date = datetime.date.fromisoformat(current_date) + elif current_date is None: + current_date = datetime.datetime.now(datetime.UTC).date() + elif not isinstance(current_date, datetime.date): + raise TypeError("current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None") + + self.table = self._compute_rfm(df, current_date) + + def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Table: + """Computes the RFM metrics and segments customers accordingly. + + Args: + df (ibis.Table): The transaction data table. + current_date (datetime.date): The reference date for calculating recency. + + Returns: + ibis.Table: A table with RFM scores and segment values. 
+ """ + cols = ColumnHelper() + current_date_expr = ibis.literal(current_date) + + customer_metrics = df.group_by(cols.customer_id).aggregate( + recency_days=(current_date_expr - df[cols.transaction_date].max().cast("date")).cast("int32"), + frequency=df[cols.transaction_id].nunique(), + monetary=df[cols.unit_spend].sum(), + ) + + window_recency = ibis.window( + order_by=[ibis.asc(customer_metrics.recency_days), ibis.asc(customer_metrics.customer_id)], + ) + window_frequency = ibis.window( + order_by=[ibis.asc(customer_metrics.frequency), ibis.asc(customer_metrics.customer_id)], + ) + window_monetary = ibis.window( + order_by=[ibis.asc(customer_metrics.monetary), ibis.asc(customer_metrics.customer_id)], + ) + + rfm_scores = customer_metrics.mutate( + r_score=(ibis.ntile(10).over(window_recency)), + f_score=(ibis.ntile(10).over(window_frequency)), + m_score=(ibis.ntile(10).over(window_monetary)), + ) + + return rfm_scores.mutate( + rfm_segment=(rfm_scores.r_score * 100 + rfm_scores.f_score * 10 + rfm_scores.m_score), + fm_segment=(rfm_scores.f_score * 10 + rfm_scores.m_score), + ) + + @property + def df(self) -> pd.DataFrame: + """Returns the dataframe with the segment names.""" + if self._df is None: + self._df = self.table.execute().set_index(get_option("column.customer_id")) + return self._df diff --git a/pyretailscience/analysis/segmentation.py b/pyretailscience/segmentation/segstats.py similarity index 50% rename from pyretailscience/analysis/segmentation.py rename to pyretailscience/segmentation/segstats.py index 8a4748a..e1a0605 100644 --- a/pyretailscience/analysis/segmentation.py +++ b/pyretailscience/segmentation/segstats.py @@ -1,6 +1,14 @@ -"""This module contains classes for segmenting customers based on their spend and transaction statistics by segment.""" +"""Module for calculating and visualizing transaction statistics by segment. 
+ +This module provides the `SegTransactionStats` class, which allows for the computation of +transaction-based statistics grouped by one or more segment columns. The statistics include +aggregations such as total spend, unique customers, transactions per customer, and optional +custom aggregations. + +The module supports both Pandas DataFrames and Ibis Tables as input data formats. It also +offers visualization capabilities to generate plots of segment-based statistics. +""" -import datetime from typing import Literal import ibis @@ -8,183 +16,10 @@ from matplotlib.axes import Axes, SubplotBase import pyretailscience.style.graph_utils as gu -from pyretailscience.options import ColumnHelper, get_option +from pyretailscience.options import ColumnHelper from pyretailscience.style.tailwind import COLORS -class BaseSegmentation: - """A base class for customer segmentation.""" - - def add_segment(self, df: pd.DataFrame) -> pd.DataFrame: - """Adds the segment to the dataframe based on the customer_id column. - - Args: - df (pd.DataFrame): The dataframe to add the segment to. The dataframe must have a customer_id column. - - Returns: - pd.DataFrame: The dataframe with the segment added. - - Raises: - ValueError: If the number of rows before and after the merge do not match. - """ - rows_before = len(df) - df = df.merge( - self.df["segment_name"], - how="left", - left_on=get_option("column.customer_id"), - right_index=True, - ) - rows_after = len(df) - if rows_before != rows_after: - raise ValueError("The number of rows before and after the merge do not match. This should not happen.") - - return df - - -class ExistingSegmentation(BaseSegmentation): - """Segments customers based on an existing segment in the dataframe.""" - - def __init__(self, df: pd.DataFrame) -> None: - """Segments customers based on an existing segment in the dataframe. - - Args: - df (pd.DataFrame): A dataframe with the customer_id and segment_name columns. 
- - Raises: - ValueError: If the dataframe does not have the columns customer_id and segment_name. - """ - cols = ColumnHelper() - required_cols = [cols.customer_id, "segment_name"] - missing_cols = set(required_cols) - set(df.columns) - if len(missing_cols) > 0: - msg = f"The following columns are required but missing: {missing_cols}" - raise ValueError(msg) - - self.df = df[[cols.customer_id, "segment_name"]].set_index(cols.customer_id) - - -class ThresholdSegmentation(BaseSegmentation): - """Segments customers based on user-defined thresholds and segments.""" - - _df: pd.DataFrame | None = None - - def __init__( - self, - df: pd.DataFrame | ibis.Table, - thresholds: list[float], - segments: dict[any, str], - value_col: str | None = None, - agg_func: str = "sum", - zero_segment_name: str = "Zero", - zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", - ) -> None: - """Segments customers based on user-defined thresholds and segments. - - Args: - df (pd.DataFrame | ibis.Table): A dataframe with the transaction data. The dataframe must contain a customer_id column. - thresholds (List[float]): The percentile thresholds for segmentation. - segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names. - value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). - agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". - zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero". - zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle - customers with zero spend. Defaults to "separate_segment". - - Raises: - ValueError: If the dataframe is missing the columns option column.customer_id or `value_col`, or these - columns contain null values. 
- """ - if len(thresholds) != len(set(thresholds)): - raise ValueError("The thresholds must be unique.") - - if len(thresholds) != len(segments): - raise ValueError("The number of thresholds must match the number of segments.") - - if isinstance(df, pd.DataFrame): - df: ibis.Table = ibis.memtable(df) - - value_col = get_option("column.unit_spend") if value_col is None else value_col - - required_cols = [get_option("column.customer_id"), value_col] - - missing_cols = set(required_cols) - set(df.columns) - if len(missing_cols) > 0: - msg = f"The following columns are required but missing: {missing_cols}" - raise ValueError(msg) - - df = df.group_by(get_option("column.customer_id")).aggregate( - **{value_col: getattr(df[value_col], agg_func)()}, - ) - - # Separate customers with zero spend - zero_df = None - if zero_value_customers == "exclude": - df = df.filter(df[value_col] != 0) - elif zero_value_customers == "separate_segment": - zero_df = df.filter(df[value_col] == 0).mutate(segment_name=ibis.literal(zero_segment_name)) - df = df.filter(df[value_col] != 0) - - window = ibis.window(order_by=ibis.asc(df[value_col])) - df = df.mutate(ptile=ibis.percent_rank().over(window)) - - case = ibis.case() - - for quantile, segment in zip(thresholds, segments, strict=True): - case = case.when(df["ptile"] <= quantile, segment) - - case = case.end() - - df = df.mutate(segment_name=case).drop(["ptile"]) - - if zero_value_customers == "separate_segment": - df = ibis.union(df, zero_df) - - self.table = df - - @property - def df(self) -> pd.DataFrame: - """Returns the dataframe with the segment names.""" - if self._df is None: - self._df = self.table.execute().set_index(get_option("column.customer_id")) - return self._df - - -class HMLSegmentation(ThresholdSegmentation): - """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.""" - - def __init__( - self, - df: pd.DataFrame | ibis.Table, - value_col: str | None = None, - agg_func: str = "sum", - 
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", - ) -> None: - """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend. - - HMLSegmentation is a subclass of ThresholdSegmentation and based around an industry standard definition. The - thresholds for Heavy (top 20%), Medium (next 30%) and Light (bottom 50%) are chosen based on the pareto - distribution, commonly know as the 80/20 rule. It is typically used in retail to segment customers based on - their spend, transaction volume or quantities purchased. - - Args: - df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. - value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). - agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". - zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle - customers with zero spend. Defaults to "separate_segment". - """ - thresholds = [0.500, 0.800, 1] - segments = ["Light", "Medium", "Heavy"] - super().__init__( - df=df, - value_col=value_col, - agg_func=agg_func, - thresholds=thresholds, - segments=segments, - zero_value_customers=zero_value_customers, - ) - - class SegTransactionStats: """Calculates transaction statistics by segment.""" @@ -462,115 +297,3 @@ def plot( gu.standard_tick_styles(ax) return ax - - -class RFMSegmentation: - """Segments customers using the RFM (Recency, Frequency, Monetary) methodology. - - Customers are scored on three dimensions: - - Recency (R): Days since the last transaction (lower is better). - - Frequency (F): Number of unique transactions (higher is better). - - Monetary (M): Total amount spent (higher is better). - - Each metric is ranked into 10 bins (0-9) using NTILE(10) where, - - 9 represents the best score (top 10% of customers). 
- - 0 represents the lowest score (bottom 10% of customers). - The RFM segment is a 3-digit number (R*100 + F*10 + M), representing customer value. - """ - - _df: pd.DataFrame | None = None - - def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | datetime.date | None = None) -> None: - """Initializes the RFM segmentation process. - - Args: - df (pd.DataFrame | ibis.Table): A DataFrame or Ibis table containing transaction data. - Must include the following columns: - - customer_id - - transaction_date - - unit_spend - - transaction_id - current_date (Optional[Union[str, datetime.date]]): The reference date for calculating recency. - Can be a string (format: "YYYY-MM-DD"), a date object, or None (defaults to the current system date). - - Raises: - ValueError: If the dataframe is missing required columns. - TypeError: If the input data is not a pandas DataFrame or an Ibis Table. - """ - cols = ColumnHelper() - required_cols = [ - cols.customer_id, - cols.transaction_date, - cols.unit_spend, - cols.transaction_id, - ] - if isinstance(df, pd.DataFrame): - df = ibis.memtable(df) - elif not isinstance(df, ibis.Table): - raise TypeError("df must be either a pandas DataFrame or an Ibis Table") - - missing_cols = set(required_cols) - set(df.columns) - if missing_cols: - error_message = f"Missing required columns: {missing_cols}" - raise ValueError(error_message) - - if isinstance(current_date, str): - current_date = datetime.date.fromisoformat(current_date) - elif current_date is None: - current_date = datetime.datetime.now(datetime.UTC).date() - elif not isinstance(current_date, datetime.date): - raise TypeError("current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None") - - self.table = self._compute_rfm(df, current_date) - - def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Table: - """Computes the RFM metrics and segments customers accordingly. - - Args: - df (ibis.Table): The transaction data table. 
- current_date (datetime.date): The reference date for calculating recency. - - Returns: - ibis.Table: A table with RFM scores and segment values. - """ - cols = ColumnHelper() - current_date_expr = ibis.literal(current_date) - - customer_metrics = df.group_by(cols.customer_id).aggregate( - recency_days=(current_date_expr - df[cols.transaction_date].max().cast("date")).cast("int32"), - frequency=df[cols.transaction_id].nunique(), - monetary=df[cols.unit_spend].sum(), - ) - - window_recency = ibis.window( - order_by=[ibis.asc(customer_metrics.recency_days), ibis.asc(customer_metrics.customer_id)], - ) - window_frequency = ibis.window( - order_by=[ibis.asc(customer_metrics.frequency), ibis.asc(customer_metrics.customer_id)], - ) - window_monetary = ibis.window( - order_by=[ibis.asc(customer_metrics.monetary), ibis.asc(customer_metrics.customer_id)], - ) - - rfm_scores = customer_metrics.mutate( - r_score=(ibis.ntile(10).over(window_recency)), - f_score=(ibis.ntile(10).over(window_frequency)), - m_score=(ibis.ntile(10).over(window_monetary)), - ) - - return rfm_scores.mutate( - rfm_segment=(rfm_scores.r_score * 100 + rfm_scores.f_score * 10 + rfm_scores.m_score), - fm_segment=(rfm_scores.f_score * 10 + rfm_scores.m_score), - ) - - @property - def df(self) -> pd.DataFrame: - """Returns the dataframe with the segment names.""" - if self._df is None: - self._df = self.table.execute().set_index(get_option("column.customer_id")) - return self._df - - @property - def ibis_table(self) -> ibis.Table: - """Returns the computed Ibis table with RFM segmentation.""" - return self.table diff --git a/pyretailscience/segmentation/threshold.py b/pyretailscience/segmentation/threshold.py new file mode 100644 index 0000000..933c0bb --- /dev/null +++ b/pyretailscience/segmentation/threshold.py @@ -0,0 +1,105 @@ +"""Threshold-Based Customer Segmentation Module. 
+ +This module provides the `ThresholdSegmentation` class, which segments customers +based on user-defined thresholds and segment mappings. + +Key Features: +- Segments customers based on specified percentile thresholds. +- Uses a specified column for segmentation, with an aggregation function applied. +- Handles customers with zero spend using configurable options. +- Utilizes Ibis for efficient query execution. +""" + +from typing import Literal + +import ibis +import pandas as pd + +from pyretailscience.options import get_option +from pyretailscience.segmentation.base import BaseSegmentation + + +class ThresholdSegmentation(BaseSegmentation): + """Segments customers based on user-defined thresholds and segments.""" + + _df: pd.DataFrame | None = None + + def __init__( + self, + df: pd.DataFrame | ibis.Table, + thresholds: list[float], + segments: list[str], + value_col: str | None = None, + agg_func: str = "sum", + zero_segment_name: str = "Zero", + zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", + ) -> None: + """Segments customers based on user-defined thresholds and segments. + + Args: + df (pd.DataFrame | ibis.Table): A dataframe with the transaction data. The dataframe must contain a customer_id column. + thresholds (List[float]): The percentile thresholds for segmentation. + segments (List[str]): A list of segment names for each threshold. + value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). + agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". + zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero". + zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle + customers with zero spend. Defaults to "separate_segment". 
+ + Raises: + ValueError: If the dataframe is missing the columns option column.customer_id or `value_col`, or these + columns contain null values. + """ + if len(thresholds) != len(set(thresholds)): + raise ValueError("The thresholds must be unique.") + + if len(thresholds) != len(segments): + raise ValueError("The number of thresholds must match the number of segments.") + + if isinstance(df, pd.DataFrame): + df: ibis.Table = ibis.memtable(df) + + value_col = get_option("column.unit_spend") if value_col is None else value_col + + required_cols = [get_option("column.customer_id"), value_col] + + missing_cols = set(required_cols) - set(df.columns) + if len(missing_cols) > 0: + msg = f"The following columns are required but missing: {missing_cols}" + raise ValueError(msg) + + df = df.group_by(get_option("column.customer_id")).aggregate( + **{value_col: getattr(df[value_col], agg_func)()}, + ) + + # Separate customers with zero spend + zero_df = None + if zero_value_customers == "exclude": + df = df.filter(df[value_col] != 0) + elif zero_value_customers == "separate_segment": + zero_df = df.filter(df[value_col] == 0).mutate(segment_name=ibis.literal(zero_segment_name)) + df = df.filter(df[value_col] != 0) + + window = ibis.window(order_by=ibis.asc(df[value_col])) + df = df.mutate(ptile=ibis.percent_rank().over(window)) + + case = ibis.case() + + for quantile, segment in zip(thresholds, segments, strict=True): + case = case.when(df["ptile"] <= quantile, segment) + + case = case.end() + + df = df.mutate(segment_name=case).drop(["ptile"]) + + if zero_value_customers == "separate_segment": + df = ibis.union(df, zero_df) + + self.table = df + + @property + def df(self) -> pd.DataFrame: + """Returns the dataframe with the segment names.""" + if self._df is None: + self._df = self.table.execute().set_index(get_option("column.customer_id")) + return self._df diff --git a/tests/analysis/test_segmentation.py b/tests/analysis/test_segmentation.py deleted file mode 100644 index 
b1768ff..0000000 --- a/tests/analysis/test_segmentation.py +++ /dev/null @@ -1,733 +0,0 @@ -"""Tests for the SegTransactionStats class.""" - -import ibis -import numpy as np -import pandas as pd -import pytest -from freezegun import freeze_time - -from pyretailscience.analysis.segmentation import ( - HMLSegmentation, - RFMSegmentation, - SegTransactionStats, - ThresholdSegmentation, -) -from pyretailscience.options import ColumnHelper, get_option - -cols = ColumnHelper() - - -class TestCalcSegStats: - """Tests for the _calc_seg_stats method.""" - - @pytest.fixture - def base_df(self): - """Return a base DataFrame for testing.""" - return pd.DataFrame( - { - cols.customer_id: [1, 2, 3, 4, 5], - cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0], - cols.transaction_id: [101, 102, 103, 104, 105], - "segment_name": ["A", "B", "A", "B", "A"], - cols.unit_qty: [10, 20, 15, 30, 25], - }, - ) - - def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df): - """Test that the method correctly calculates at the transaction-item level.""" - expected_output = pd.DataFrame( - { - "segment_name": ["A", "B", "Total"], - cols.agg_unit_spend: [500.0, 500.0, 1000.0], - cols.agg_transaction_id: [3, 2, 5], - cols.agg_customer_id: [3, 2, 5], - cols.agg_unit_qty: [50, 50, 100], - cols.calc_spend_per_cust: [166.666667, 250.0, 200.0], - cols.calc_spend_per_trans: [166.666667, 250.0, 200.0], - cols.calc_trans_per_cust: [1.0, 1.0, 1.0], - cols.calc_price_per_unit: [10.0, 10.0, 10.0], - cols.calc_units_per_trans: [16.666667, 25.0, 20.0], - cols.customers_pct: [0.6, 0.4, 1.0], - }, - ) - segment_stats = ( - SegTransactionStats(base_df, "segment_name").df.sort_values("segment_name").reset_index(drop=True) - ) - pd.testing.assert_frame_equal(segment_stats, expected_output) - - def test_correctly_calculates_revenue_transactions_customers(self): - """Test that the method correctly calculates at the transaction level.""" - df = pd.DataFrame( - { - 
get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0], - cols.transaction_id: [101, 102, 103, 104, 105], - "segment_name": ["A", "B", "A", "B", "A"], - }, - ) - - expected_output = pd.DataFrame( - { - "segment_name": ["A", "B", "Total"], - cols.agg_unit_spend: [500.0, 500.0, 1000.0], - cols.agg_transaction_id: [3, 2, 5], - cols.agg_customer_id: [3, 2, 5], - cols.calc_spend_per_cust: [166.666667, 250.0, 200.0], - cols.calc_spend_per_trans: [166.666667, 250.0, 200.0], - cols.calc_trans_per_cust: [1.0, 1.0, 1.0], - cols.customers_pct: [0.6, 0.4, 1.0], - }, - ) - - segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True) - pd.testing.assert_frame_equal(segment_stats, expected_output) - - def test_handles_dataframe_with_one_segment(self, base_df): - """Test that the method correctly handles a DataFrame with only one segment.""" - df = base_df.copy() - df["segment_name"] = "A" - - expected_output = pd.DataFrame( - { - "segment_name": ["A", "Total"], - cols.agg_unit_spend: [1000.0, 1000.0], - cols.agg_transaction_id: [5, 5], - cols.agg_customer_id: [5, 5], - cols.agg_unit_qty: [100, 100], - cols.calc_spend_per_cust: [200.0, 200.0], - cols.calc_spend_per_trans: [200.0, 200.0], - cols.calc_trans_per_cust: [1.0, 1.0], - cols.calc_price_per_unit: [10.0, 10.0], - cols.calc_units_per_trans: [20.0, 20.0], - cols.customers_pct: [1.0, 1.0], - }, - ) - - segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True) - pd.testing.assert_frame_equal(segment_stats, expected_output) - - def test_handles_dataframe_with_zero_net_units(self, base_df): - """Test that the method correctly handles a DataFrame with a segment with net zero units.""" - df = base_df.copy() - df[cols.unit_qty] = [10, 20, 15, 30, -25] - - expected_output = pd.DataFrame( - { - "segment_name": ["A", "B", "Total"], - cols.agg_unit_spend: [500.0, 500.0, 1000.0], - 
cols.agg_transaction_id: [3, 2, 5], - cols.agg_customer_id: [3, 2, 5], - cols.agg_unit_qty: [0, 50, 50], - cols.calc_spend_per_cust: [166.666667, 250.0, 200.0], - cols.calc_spend_per_trans: [166.666667, 250.0, 200.0], - cols.calc_trans_per_cust: [1.0, 1.0, 1.0], - cols.calc_price_per_unit: [np.nan, 10.0, 20.0], - cols.calc_units_per_trans: [0, 25.0, 10.0], - cols.customers_pct: [0.6, 0.4, 1.0], - }, - ) - segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True) - - pd.testing.assert_frame_equal(segment_stats, expected_output) - - def test_excludes_total_row_when_calc_total_false(self, base_df): - """Test that the method excludes the total row when calc_total=False.""" - expected_output = pd.DataFrame( - { - "segment_name": ["A", "B"], - cols.agg_unit_spend: [500.0, 500.0], - cols.agg_transaction_id: [3, 2], - cols.agg_customer_id: [3, 2], - cols.agg_unit_qty: [50, 50], - cols.calc_spend_per_cust: [166.666667, 250.0], - cols.calc_spend_per_trans: [166.666667, 250.0], - cols.calc_trans_per_cust: [1.0, 1.0], - cols.calc_price_per_unit: [10.0, 10.0], - cols.calc_units_per_trans: [16.666667, 25.0], - cols.customers_pct: [1.0, 1.0], - }, - ) - - segment_stats = ( - SegTransactionStats(base_df, "segment_name", calc_total=False) - .df.sort_values("segment_name") - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(segment_stats, expected_output) - - -class TestThresholdSegmentation: - """Tests for the ThresholdSegmentation class.""" - - def test_correct_segmentation(self): - """Test that the method correctly segments customers based on given thresholds and segments.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4], - cols.unit_spend: [100, 200, 300, 400], - }, - ) - thresholds = [0.5, 1] - segments = ["Low", "High"] - seg = ThresholdSegmentation( - df=df, - thresholds=thresholds, - segments=segments, - value_col=cols.unit_spend, - zero_value_customers="exclude", - ) - result_df = 
seg.df - assert result_df.loc[1, "segment_name"] == "Low" - assert result_df.loc[2, "segment_name"] == "Low" - assert result_df.loc[3, "segment_name"] == "High" - assert result_df.loc[4, "segment_name"] == "High" - - def test_single_customer(self): - """Test that the method correctly segments a DataFrame with only one customer.""" - df = pd.DataFrame({get_option("column.customer_id"): [1], cols.unit_spend: [100]}) - thresholds = [0.5, 1] - segments = ["Low"] - with pytest.raises(ValueError): - ThresholdSegmentation( - df=df, - thresholds=thresholds, - segments=segments, - ) - - def test_correct_aggregation_function(self): - """Test that the correct aggregation function is applied for product_id custom segmentation.""" - df = pd.DataFrame( - { - cols.customer_id: [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5], - "product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1], - }, - ) - value_col = "product_id" - agg_func = "nunique" - - my_seg = ThresholdSegmentation( - df=df, - value_col=value_col, - agg_func=agg_func, - thresholds=[0.2, 0.8, 1], - segments=["Low", "Medium", "High"], - zero_value_customers="separate_segment", - ) - - expected_result = pd.DataFrame( - { - cols.customer_id: [1, 2, 3, 4, 5], - "product_id": [1, 4, 2, 2, 3], - "segment_name": ["Low", "High", "Medium", "Medium", "Medium"], - }, - ) - pd.testing.assert_frame_equal(my_seg.df.sort_values(cols.customer_id).reset_index(), expected_result) - - def test_correctly_checks_segment_data(self): - """Test that the method correctly merges segment data back into the original DataFrame.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100, 200, 0, 150, 0], - }, - ) - value_col = cols.unit_spend - agg_func = "sum" - thresholds = [0.33, 0.66, 1] - segments = ["Low", "Medium", "High"] - zero_value_customers = "separate_segment" - - # Create ThresholdSegmentation instance - threshold_seg = ThresholdSegmentation( - df=df, - value_col=value_col, - agg_func=agg_func, 
- thresholds=thresholds, - segments=segments, - zero_value_customers=zero_value_customers, - ) - - # Call add_segment method - segmented_df = threshold_seg.add_segment(df) - - # Assert the correct segment_name - expected_df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100, 200, 0, 150, 0], - "segment_name": ["Low", "High", "Zero", "Medium", "Zero"], - }, - ) - pd.testing.assert_frame_equal(segmented_df, expected_df) - - def test_handles_dataframe_with_duplicate_customer_id_entries(self): - """Test that the method correctly handles a DataFrame with duplicate customer_id entries.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 1, 2, 3], - cols.unit_spend: [100, 200, 300, 150, 250, 350], - }, - ) - - my_seg = ThresholdSegmentation( - df=df, - value_col=cols.unit_spend, - agg_func="sum", - thresholds=[0.5, 0.8, 1], - segments=["Light", "Medium", "Heavy"], - zero_value_customers="include_with_light", - ) - - result_df = my_seg.add_segment(df) - assert len(result_df) == len(df) - - def test_thresholds_not_unique(self): - """Test that the method raises an error when the thresholds are not unique.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100, 200, 300, 400, 500], - }, - ) - thresholds = [0.5, 0.5, 0.8, 1] - segments = ["Low", "Medium", "High"] - - with pytest.raises(ValueError): - ThresholdSegmentation(df, thresholds, segments) - - def test_thresholds_too_few_segments(self): - """Test that the method raises an error when there are too few/many segments for the number of thresholds.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100, 200, 300, 400, 500], - }, - ) - thresholds = [0.4, 0.6, 0.8, 1] - segments = ["Low", "High"] - - with pytest.raises(ValueError): - ThresholdSegmentation(df, thresholds, segments) - - segments = ["Low", "Medium", "High"] - - with pytest.raises(ValueError): - 
ThresholdSegmentation(df, thresholds, segments) - - def test_thresholds_too_too_few_thresholds(self): - """Test that the method raises an error when there are too few/many thresholds for the number of segments.""" - df = pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [100, 200, 300, 400, 500], - }, - ) - thresholds = [0.4, 1] - segments = ["Low", "Medium", "High"] - - with pytest.raises(ValueError): - ThresholdSegmentation(df, thresholds, segments) - - thresholds = [0.2, 0.5, 0.6, 0.8, 1] - - with pytest.raises(ValueError): - ThresholdSegmentation(df, thresholds, segments) - - -class TestSegTransactionStats: - """Tests for the SegTransactionStats class.""" - - def test_handles_empty_dataframe_with_errors(self): - """Test that the method raises an error when the DataFrame is missing a required column.""" - df = pd.DataFrame( - columns=[cols.unit_spend, cols.transaction_id, cols.unit_qty], - ) - - with pytest.raises(ValueError): - SegTransactionStats(df, "segment_name") - - def test_multiple_segment_columns(self): - """Test that the class correctly handles multiple segment columns.""" - df = pd.DataFrame( - { - cols.customer_id: [1, 1, 2, 2, 3, 3], - cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0], - cols.transaction_id: [101, 102, 103, 104, 105, 106], - "segment_name": ["A", "A", "B", "B", "A", "A"], - "region": ["North", "North", "South", "South", "East", "East"], - }, - ) - - # Test with a list of segment columns - seg_stats = SegTransactionStats(df, ["segment_name", "region"]) - - # Create expected DataFrame with the combinations actually produced - expected_output = pd.DataFrame( - { - "segment_name": ["A", "A", "B", "Total"], - "region": ["East", "North", "South", "Total"], - cols.agg_unit_spend: [650.0, 250.0, 450.0, 1350.0], - cols.agg_transaction_id: [2, 2, 2, 6], - cols.agg_customer_id: [1, 1, 1, 3], - cols.calc_spend_per_cust: [650.0, 250.0, 450.0, 450.0], - cols.calc_spend_per_trans: [325.0, 125.0, 
225.0, 225.0], - cols.calc_trans_per_cust: [2.0, 2.0, 2.0, 2.0], - cols.customers_pct: [1 / 3, 1 / 3, 1 / 3, 1.0], - }, - ) - - # Sort both dataframes by the segment columns for consistent comparison - result_df = seg_stats.df.sort_values(["segment_name", "region"]).reset_index(drop=True) - expected_output = expected_output.sort_values(["segment_name", "region"]).reset_index(drop=True) - - # Check that both segment columns are in the result - assert "segment_name" in result_df.columns - assert "region" in result_df.columns - - # Check number of rows - the implementation only returns actual combinations that exist in data - # plus the Total row, not all possible combinations - assert len(result_df) == len(expected_output) - - # Use pandas testing to compare the dataframes - pd.testing.assert_frame_equal(result_df[expected_output.columns], expected_output) - - def test_plot_with_multiple_segment_columns(self): - """Test that plotting with multiple segment columns raises a ValueError.""" - df = pd.DataFrame( - { - cols.customer_id: [1, 2, 3], - cols.unit_spend: [100.0, 200.0, 300.0], - cols.transaction_id: [101, 102, 103], - "segment_name": ["A", "B", "A"], - "region": ["North", "South", "East"], - }, - ) - - seg_stats = SegTransactionStats(df, ["segment_name", "region"]) - - with pytest.raises(ValueError) as excinfo: - seg_stats.plot("spend") - - assert "Plotting is only supported for a single segment column" in str(excinfo.value) - - def test_extra_aggs_functionality(self): - """Test that the extra_aggs parameter works correctly.""" - # Constants for expected values - segment_a_store_count = 3 # Segment A has stores 1, 2, 4 - segment_b_store_count = 2 # Segment B has stores 1, 3 - total_store_count = 4 # Total has stores 1, 2, 3, 4 - - segment_a_product_count = 3 # Segment A has products 10, 20, 40 - segment_b_product_count = 2 # Segment B has products 10, 30 - total_product_count = 4 # Total has products 10, 20, 30, 40 - df = pd.DataFrame( - { - cols.customer_id: 
[1, 1, 2, 2, 3, 3], - cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0], - cols.transaction_id: [101, 102, 103, 104, 105, 106], - "segment_name": ["A", "A", "B", "B", "A", "A"], - "store_id": [1, 2, 1, 3, 2, 4], - "product_id": [10, 20, 10, 30, 20, 40], - }, - ) - - # Test with a single extra aggregation - seg_stats = SegTransactionStats( - df, - "segment_name", - extra_aggs={"distinct_stores": ("store_id", "nunique")}, - ) - - # Verify the extra column exists and has correct values - assert "distinct_stores" in seg_stats.df.columns - - # Sort by segment_name to ensure consistent order - result_df = seg_stats.df.sort_values("segment_name").reset_index(drop=True) - - assert result_df.loc[0, "distinct_stores"] == segment_a_store_count # Segment A - assert result_df.loc[1, "distinct_stores"] == segment_b_store_count # Segment B - assert result_df.loc[2, "distinct_stores"] == total_store_count # Total - - # Test with multiple extra aggregations - seg_stats_multi = SegTransactionStats( - df, - "segment_name", - extra_aggs={ - "distinct_stores": ("store_id", "nunique"), - "distinct_products": ("product_id", "nunique"), - }, - ) - - # Verify both extra columns exist - assert "distinct_stores" in seg_stats_multi.df.columns - assert "distinct_products" in seg_stats_multi.df.columns - - # Sort by segment_name to ensure consistent order - result_df_multi = seg_stats_multi.df.sort_values("segment_name").reset_index(drop=True) - - assert result_df_multi["distinct_products"].to_list() == [ - segment_a_product_count, - segment_b_product_count, - total_product_count, - ] - - def test_extra_aggs_with_invalid_column(self): - """Test that an error is raised when an invalid column is specified in extra_aggs.""" - df = pd.DataFrame( - { - cols.customer_id: [1, 2, 3], - cols.unit_spend: [100.0, 200.0, 300.0], - cols.transaction_id: [101, 102, 103], - "segment_name": ["A", "B", "A"], - }, - ) - - with pytest.raises(ValueError) as excinfo: - SegTransactionStats(df, 
"segment_name", extra_aggs={"invalid_agg": ("nonexistent_column", "nunique")}) - - assert "does not exist in the data" in str(excinfo.value) - - def test_extra_aggs_with_invalid_function(self): - """Test that an error is raised when an invalid function is specified in extra_aggs.""" - df = pd.DataFrame( - { - cols.customer_id: [1, 2, 3], - cols.unit_spend: [100.0, 200.0, 300.0], - cols.transaction_id: [101, 102, 103], - "segment_name": ["A", "B", "A"], - }, - ) - - with pytest.raises(ValueError) as excinfo: - SegTransactionStats(df, "segment_name", extra_aggs={"invalid_agg": (cols.customer_id, "invalid_function")}) - - assert "not available for column" in str(excinfo.value) - - -class TestHMLSegmentation: - """Tests for the HMLSegmentation class.""" - - @pytest.fixture - def base_df(self): - """Return a base DataFrame for testing.""" - return pd.DataFrame( - { - get_option("column.customer_id"): [1, 2, 3, 4, 5], - cols.unit_spend: [1000, 200, 0, 500, 300], - }, - ) - - # Correctly handles zero spend customers when zero_value_customers is "exclude" - def test_handles_zero_spend_customers_are_excluded_in_result(self, base_df): - """Test that the method correctly handles zero spend customers when zero_value_customers is "exclude".""" - hml_segmentation = HMLSegmentation(base_df, zero_value_customers="exclude") - result_df = hml_segmentation.df - - zero_spend_customer_id = 3 - - assert result_df.loc[1, "segment_name"] == "Heavy" - assert result_df.loc[2, "segment_name"] == "Light" - assert zero_spend_customer_id not in result_df.index - assert result_df.loc[4, "segment_name"] == "Medium" - assert result_df.loc[5, "segment_name"] == "Light" - - # Correctly handles zero spend customers when zero_value_customers is "include_with_light" - def test_handles_zero_spend_customers_include_with_light(self, base_df): - """Test that the method correctly handles zero spend customers when zero_value_customers is "include_with_light".""" - hml_segmentation = HMLSegmentation(base_df, 
zero_value_customers="include_with_light") - result_df = hml_segmentation.df - - assert result_df.loc[1, "segment_name"] == "Heavy" - assert result_df.loc[2, "segment_name"] == "Light" - assert result_df.loc[3, "segment_name"] == "Light" - assert result_df.loc[4, "segment_name"] == "Medium" - assert result_df.loc[5, "segment_name"] == "Light" - - # Correctly handles zero spend customers when zero_value_customers is "separate_segment" - def test_handles_zero_spend_customers_separate_segment(self, base_df): - """Test that the method correctly handles zero spend customers when zero_value_customers is "separate_segment".""" - hml_segmentation = HMLSegmentation(base_df, zero_value_customers="separate_segment") - result_df = hml_segmentation.df - - assert result_df.loc[1, "segment_name"] == "Heavy" - assert result_df.loc[2, "segment_name"] == "Light" - assert result_df.loc[3, "segment_name"] == "Zero" - assert result_df.loc[4, "segment_name"] == "Medium" - assert result_df.loc[5, "segment_name"] == "Light" - - # Raises ValueError if required columns are missing - def test_raises_value_error_if_required_columns_missing(self, base_df): - """Test that the method raises an error when the DataFrame is missing a required column.""" - with pytest.raises(ValueError): - HMLSegmentation(base_df.drop(columns=[get_option("column.customer_id")])) - - # Validate that the input dataframe is not changed - def test_input_dataframe_not_changed(self, base_df): - """Test that the method does not alter the original DataFrame.""" - original_df = base_df.copy() - - hml_segmentation = HMLSegmentation(base_df) - _ = hml_segmentation.df - - assert original_df.equals(base_df) # Check if the original dataframe is not changed - - def test_alternate_value_col(self, base_df): - """Test that the method correctly segments a DataFrame with an alternate value column.""" - base_df = base_df.rename(columns={cols.unit_spend: cols.unit_qty}) - hml_segmentation = HMLSegmentation(base_df, 
value_col=cols.unit_qty) - result_df = hml_segmentation.df - - assert result_df.loc[1, "segment_name"] == "Heavy" - assert result_df.loc[2, "segment_name"] == "Light" - assert result_df.loc[4, "segment_name"] == "Medium" - assert result_df.loc[5, "segment_name"] == "Light" - - -class TestRFMSegmentation: - """Tests for the RFMSegmentation class.""" - - @pytest.fixture - def base_df(self): - """Return a base DataFrame for testing.""" - return pd.DataFrame( - { - cols.customer_id: [1, 2, 3, 4, 5], - cols.transaction_id: [101, 102, 103, 104, 105], - cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0], - cols.transaction_date: [ - "2025-03-01", - "2025-02-15", - "2025-01-30", - "2025-03-10", - "2025-02-20", - ], - }, - ) - - @pytest.fixture - def expected_df(self): - """Returns the expected DataFrame for testing segmentation.""" - return pd.DataFrame( - { - "customer_id": [1, 2, 3, 4, 5], - "frequency": [1, 1, 1, 1, 1], - "monetary": [100.0, 200.0, 150.0, 300.0, 250.0], - "r_score": [1, 3, 4, 0, 2], - "f_score": [0, 1, 2, 3, 4], - "m_score": [0, 2, 1, 4, 3], - "rfm_segment": [100, 312, 421, 34, 243], - "fm_segment": [0, 12, 21, 34, 43], - }, - ).set_index("customer_id") - - def test_correct_rfm_segmentation(self, base_df, expected_df): - """Test that the RFM segmentation correctly calculates the RFM scores and segments.""" - current_date = "2025-03-17" - rfm_segmentation = RFMSegmentation(df=base_df, current_date=current_date) - result_df = rfm_segmentation.df - expected_df["recency_days"] = [16, 30, 46, 7, 25] - expected_df["recency_days"] = expected_df["recency_days"].astype(result_df["recency_days"].dtype) - - pd.testing.assert_frame_equal( - result_df.sort_index(), - expected_df.sort_index(), - check_like=True, - ) - - def test_handles_dataframe_with_missing_columns(self): - """Test that the method raises an error when required columns are missing.""" - base_df = pd.DataFrame( - { - cols.customer_id: [1, 2, 3], - cols.unit_spend: [100.0, 200.0, 150.0], - 
cols.transaction_id: [101, 102, 103], - }, - ) - - with pytest.raises(ValueError): - RFMSegmentation(df=base_df, current_date="2025-03-17") - - def test_single_customer(self): - """Test that the method correctly calculates RFM segmentation for a single customer.""" - df_single_customer = pd.DataFrame( - { - cols.customer_id: [1], - cols.transaction_id: [101], - cols.unit_spend: [200.0], - cols.transaction_date: ["2025-03-01"], - }, - ) - current_date = "2025-03-17" - rfm_segmentation = RFMSegmentation(df=df_single_customer, current_date=current_date) - result_df = rfm_segmentation.df - assert result_df.loc[1, "rfm_segment"] == 0 - - def test_multiple_transactions_per_customer(self): - """Test that the method correctly handles multiple transactions for the same customer.""" - df_multiple_transactions = pd.DataFrame( - { - cols.customer_id: [1, 1, 1, 1, 1], - cols.transaction_id: [101, 102, 103, 104, 105], - cols.unit_spend: [120.0, 250.0, 180.0, 300.0, 220.0], - cols.transaction_date: [ - "2025-03-01", - "2025-02-15", - "2025-01-10", - "2025-03-10", - "2025-02-25", - ], - }, - ) - current_date = "2025-03-17" - rfm_segmentation = RFMSegmentation(df=df_multiple_transactions, current_date=current_date) - result_df = rfm_segmentation.df - - assert result_df.loc[1, "rfm_segment"] == 0 - - def test_calculates_rfm_correctly_for_all_customers(self, base_df): - """Test that RFM scores are calculated correctly for all customers.""" - current_date = "2025-03-17" - expected_customer_count = 5 - rfm_segmentation = RFMSegmentation(df=base_df, current_date=current_date) - result_df = rfm_segmentation.df - - assert len(result_df) == expected_customer_count - assert "rfm_segment" in result_df.columns - - @freeze_time("2025-03-19") - def test_rfm_segmentation_with_no_date(self, base_df, expected_df): - """Test that the RFM segmentation correctly calculates the RFM scores and segments.""" - rfm_segmentation = RFMSegmentation(df=base_df) - result_df = rfm_segmentation.df - 
expected_df["recency_days"] = [18, 32, 48, 9, 27] - expected_df["recency_days"] = expected_df["recency_days"].astype(result_df["recency_days"].dtype) - - pd.testing.assert_frame_equal( - result_df.sort_index(), - expected_df.sort_index(), - check_like=True, - ) - - def test_invalid_current_date_type(self, base_df): - """Test that RFMSegmentation raises a TypeError when an invalid current_date is provided.""" - with pytest.raises( - TypeError, - match="current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None", - ): - RFMSegmentation(base_df, current_date=12345) - - def test_invalid_df_type(self): - """Test that RFMSegmentation raises a TypeError when df is neither a DataFrame nor an Ibis Table.""" - invalid_df = "this is not a dataframe" - - with pytest.raises(TypeError, match="df must be either a pandas DataFrame or an Ibis Table"): - RFMSegmentation(df=invalid_df, current_date="2025-03-17") - - def test_ibis_table_property(self, base_df): - """Test that ibis_table property returns an Ibis Table.""" - segmentation = RFMSegmentation(df=base_df, current_date="2025-03-17") - - result = segmentation.ibis_table - - assert isinstance(result, ibis.Table), "Expected ibis.Table but got a different type" diff --git a/tests/segmentation/test_hml.py b/tests/segmentation/test_hml.py new file mode 100644 index 0000000..65f6db7 --- /dev/null +++ b/tests/segmentation/test_hml.py @@ -0,0 +1,88 @@ +"""Tests for the HMLSegmentation class.""" + +import pandas as pd +import pytest + +from pyretailscience.options import ColumnHelper, get_option +from pyretailscience.segmentation.hml import HMLSegmentation + +cols = ColumnHelper() + + +class TestHMLSegmentation: + """Tests for the HMLSegmentation class.""" + + @pytest.fixture + def base_df(self): + """Return a base DataFrame for testing.""" + return pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [1000, 200, 0, 500, 300], + }, + ) + + # Correctly handles zero spend 
customers when zero_value_customers is "exclude" + def test_handles_zero_spend_customers_are_excluded_in_result(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "exclude".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="exclude") + result_df = hml_segmentation.df + + zero_spend_customer_id = 3 + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[2, "segment_name"] == "Light" + assert zero_spend_customer_id not in result_df.index + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[5, "segment_name"] == "Light" + + # Correctly handles zero spend customers when zero_value_customers is "include_with_light" + def test_handles_zero_spend_customers_include_with_light(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "include_with_light".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="include_with_light") + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[3, "segment_name"] == "Light" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[5, "segment_name"] == "Light" + + # Correctly handles zero spend customers when zero_value_customers is "separate_segment" + def test_handles_zero_spend_customers_separate_segment(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "separate_segment".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="separate_segment") + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[3, "segment_name"] == "Zero" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[5, "segment_name"] == "Light" 
+ + # Raises ValueError if required columns are missing + def test_raises_value_error_if_required_columns_missing(self, base_df): + """Test that the method raises an error when the DataFrame is missing a required column.""" + with pytest.raises(ValueError): + HMLSegmentation(base_df.drop(columns=[get_option("column.customer_id")])) + + # Validate that the input dataframe is not changed + def test_input_dataframe_not_changed(self, base_df): + """Test that the method does not alter the original DataFrame.""" + original_df = base_df.copy() + + hml_segmentation = HMLSegmentation(base_df) + _ = hml_segmentation.df + + assert original_df.equals(base_df) # Check if the original dataframe is not changed + + def test_alternate_value_col(self, base_df): + """Test that the method correctly segments a DataFrame with an alternate value column.""" + base_df = base_df.rename(columns={cols.unit_spend: cols.unit_qty}) + hml_segmentation = HMLSegmentation(base_df, value_col=cols.unit_qty) + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[5, "segment_name"] == "Light" diff --git a/tests/segmentation/test_rfm.py b/tests/segmentation/test_rfm.py new file mode 100644 index 0000000..648743c --- /dev/null +++ b/tests/segmentation/test_rfm.py @@ -0,0 +1,151 @@ +"""Tests for the RFMSegmentation class.""" + +import pandas as pd +import pytest +from freezegun import freeze_time + +from pyretailscience.options import ColumnHelper +from pyretailscience.segmentation.rfm import RFMSegmentation + +cols = ColumnHelper() + + +class TestRFMSegmentation: + """Tests for the RFMSegmentation class.""" + + @pytest.fixture + def base_df(self): + """Return a base DataFrame for testing.""" + return pd.DataFrame( + { + cols.customer_id: [1, 2, 3, 4, 5], + cols.transaction_id: [101, 102, 103, 104, 105], + cols.unit_spend: [100.0, 200.0, 150.0, 
300.0, 250.0], + cols.transaction_date: [ + "2025-03-01", + "2025-02-15", + "2025-01-30", + "2025-03-10", + "2025-02-20", + ], + }, + ) + + @pytest.fixture + def expected_df(self): + """Returns the expected DataFrame for testing segmentation.""" + return pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "frequency": [1, 1, 1, 1, 1], + "monetary": [100.0, 200.0, 150.0, 300.0, 250.0], + "r_score": [1, 3, 4, 0, 2], + "f_score": [0, 1, 2, 3, 4], + "m_score": [0, 2, 1, 4, 3], + "rfm_segment": [100, 312, 421, 34, 243], + "fm_segment": [0, 12, 21, 34, 43], + }, + ).set_index("customer_id") + + def test_correct_rfm_segmentation(self, base_df, expected_df): + """Test that the RFM segmentation correctly calculates the RFM scores and segments.""" + current_date = "2025-03-17" + rfm_segmentation = RFMSegmentation(df=base_df, current_date=current_date) + result_df = rfm_segmentation.df + expected_df["recency_days"] = [16, 30, 46, 7, 25] + expected_df["recency_days"] = expected_df["recency_days"].astype(result_df["recency_days"].dtype) + + pd.testing.assert_frame_equal( + result_df.sort_index(), + expected_df.sort_index(), + check_like=True, + ) + + def test_handles_dataframe_with_missing_columns(self): + """Test that the method raises an error when required columns are missing.""" + base_df = pd.DataFrame( + { + cols.customer_id: [1, 2, 3], + cols.unit_spend: [100.0, 200.0, 150.0], + cols.transaction_id: [101, 102, 103], + }, + ) + + with pytest.raises(ValueError): + RFMSegmentation(df=base_df, current_date="2025-03-17") + + def test_single_customer(self): + """Test that the method correctly calculates RFM segmentation for a single customer.""" + df_single_customer = pd.DataFrame( + { + cols.customer_id: [1], + cols.transaction_id: [101], + cols.unit_spend: [200.0], + cols.transaction_date: ["2025-03-01"], + }, + ) + current_date = "2025-03-17" + rfm_segmentation = RFMSegmentation(df=df_single_customer, current_date=current_date) + result_df = rfm_segmentation.df + assert 
result_df.loc[1, "rfm_segment"] == 0 + + def test_multiple_transactions_per_customer(self): + """Test that the method correctly handles multiple transactions for the same customer.""" + df_multiple_transactions = pd.DataFrame( + { + cols.customer_id: [1, 1, 1, 1, 1], + cols.transaction_id: [101, 102, 103, 104, 105], + cols.unit_spend: [120.0, 250.0, 180.0, 300.0, 220.0], + cols.transaction_date: [ + "2025-03-01", + "2025-02-15", + "2025-01-10", + "2025-03-10", + "2025-02-25", + ], + }, + ) + current_date = "2025-03-17" + rfm_segmentation = RFMSegmentation(df=df_multiple_transactions, current_date=current_date) + result_df = rfm_segmentation.df + + assert result_df.loc[1, "rfm_segment"] == 0 + + def test_calculates_rfm_correctly_for_all_customers(self, base_df): + """Test that RFM scores are calculated correctly for all customers.""" + current_date = "2025-03-17" + expected_customer_count = 5 + rfm_segmentation = RFMSegmentation(df=base_df, current_date=current_date) + result_df = rfm_segmentation.df + + assert len(result_df) == expected_customer_count + assert "rfm_segment" in result_df.columns + + @freeze_time("2025-03-19") + def test_rfm_segmentation_with_no_date(self, base_df, expected_df): + """Test that the RFM segmentation correctly calculates the RFM scores and segments.""" + rfm_segmentation = RFMSegmentation(df=base_df) + result_df = rfm_segmentation.df + expected_df["recency_days"] = [18, 32, 48, 9, 27] + expected_df["recency_days"] = expected_df["recency_days"].astype(result_df["recency_days"].dtype) + + pd.testing.assert_frame_equal( + result_df.sort_index(), + expected_df.sort_index(), + check_like=True, + ) + + def test_invalid_current_date_type(self, base_df): + """Test that RFMSegmentation raises a TypeError when an invalid current_date is provided.""" + with pytest.raises( + TypeError, + match="current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None", + ): + RFMSegmentation(base_df, current_date=12345) + + def 
test_invalid_df_type(self):
        """Test that RFMSegmentation raises a TypeError when df is neither a DataFrame nor an Ibis Table."""
        invalid_df = "this is not a dataframe"

        with pytest.raises(TypeError, match="df must be either a pandas DataFrame or an Ibis Table"):
            RFMSegmentation(df=invalid_df, current_date="2025-03-17")
diff --git a/tests/segmentation/test_segstats.py b/tests/segmentation/test_segstats.py
new file mode 100644
index 0000000..a6902bd
--- /dev/null
+++ b/tests/segmentation/test_segstats.py
@@ -0,0 +1,320 @@
"""Tests for the SegTransactionStats class."""

import numpy as np
import pandas as pd
import pytest

from pyretailscience.options import ColumnHelper, get_option
from pyretailscience.segmentation.segstats import SegTransactionStats

# Shared helper resolving configurable column names (customer_id, unit_spend, ...).
cols = ColumnHelper()


class TestCalcSegStats:
    """Tests for the _calc_seg_stats method."""

    @pytest.fixture
    def base_df(self):
        """Return a base DataFrame for testing."""
        # 5 customers, one transaction each; segment A = customers 1/3/5, B = 2/4.
        return pd.DataFrame(
            {
                cols.customer_id: [1, 2, 3, 4, 5],
                cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0],
                cols.transaction_id: [101, 102, 103, 104, 105],
                "segment_name": ["A", "B", "A", "B", "A"],
                cols.unit_qty: [10, 20, 15, 30, 25],
            },
        )

    def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df):
        """Test that the method correctly calculates at the transaction-item level."""
        # Repeating decimals are written to 6 dp; pd.testing.assert_frame_equal
        # compares floats approximately by default, so this is sufficient.
        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "B", "Total"],
                cols.agg_unit_spend: [500.0, 500.0, 1000.0],
                cols.agg_transaction_id: [3, 2, 5],
                cols.agg_customer_id: [3, 2, 5],
                cols.agg_unit_qty: [50, 50, 100],
                cols.calc_spend_per_cust: [166.666667, 250.0, 200.0],
                cols.calc_spend_per_trans: [166.666667, 250.0, 200.0],
                cols.calc_trans_per_cust: [1.0, 1.0, 1.0],
                cols.calc_price_per_unit: [10.0, 10.0, 10.0],
                cols.calc_units_per_trans: [16.666667, 25.0, 20.0],
                cols.customers_pct: [0.6, 0.4, 1.0],
            },
        )
        segment_stats = (
            SegTransactionStats(base_df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)
        )
        pd.testing.assert_frame_equal(segment_stats, expected_output)

    def test_correctly_calculates_revenue_transactions_customers(self):
        """Test that the method correctly calculates at the transaction level."""
        # No unit_qty column: quantity-derived stats should be absent from the output.
        df = pd.DataFrame(
            {
                get_option("column.customer_id"): [1, 2, 3, 4, 5],
                cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0],
                cols.transaction_id: [101, 102, 103, 104, 105],
                "segment_name": ["A", "B", "A", "B", "A"],
            },
        )

        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "B", "Total"],
                cols.agg_unit_spend: [500.0, 500.0, 1000.0],
                cols.agg_transaction_id: [3, 2, 5],
                cols.agg_customer_id: [3, 2, 5],
                cols.calc_spend_per_cust: [166.666667, 250.0, 200.0],
                cols.calc_spend_per_trans: [166.666667, 250.0, 200.0],
                cols.calc_trans_per_cust: [1.0, 1.0, 1.0],
                cols.customers_pct: [0.6, 0.4, 1.0],
            },
        )

        segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)
        pd.testing.assert_frame_equal(segment_stats, expected_output)

    def test_handles_dataframe_with_one_segment(self, base_df):
        """Test that the method correctly handles a DataFrame with only one segment."""
        df = base_df.copy()
        df["segment_name"] = "A"

        # With a single segment the segment row and the Total row are identical.
        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "Total"],
                cols.agg_unit_spend: [1000.0, 1000.0],
                cols.agg_transaction_id: [5, 5],
                cols.agg_customer_id: [5, 5],
                cols.agg_unit_qty: [100, 100],
                cols.calc_spend_per_cust: [200.0, 200.0],
                cols.calc_spend_per_trans: [200.0, 200.0],
                cols.calc_trans_per_cust: [1.0, 1.0],
                cols.calc_price_per_unit: [10.0, 10.0],
                cols.calc_units_per_trans: [20.0, 20.0],
                cols.customers_pct: [1.0, 1.0],
            },
        )

        segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)
        pd.testing.assert_frame_equal(segment_stats, expected_output)

    def test_handles_dataframe_with_zero_net_units(self, base_df):
        """Test that the method correctly handles a DataFrame with a segment with net zero units."""
        df = base_df.copy()
        df[cols.unit_qty] = [10, 20, 15, 30, -25]

        # Segment A's quantities net to 0, so price-per-unit is NaN (division by
        # zero units) and units-per-transaction is 0 for that segment.
        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "B", "Total"],
                cols.agg_unit_spend: [500.0, 500.0, 1000.0],
                cols.agg_transaction_id: [3, 2, 5],
                cols.agg_customer_id: [3, 2, 5],
                cols.agg_unit_qty: [0, 50, 50],
                cols.calc_spend_per_cust: [166.666667, 250.0, 200.0],
                cols.calc_spend_per_trans: [166.666667, 250.0, 200.0],
                cols.calc_trans_per_cust: [1.0, 1.0, 1.0],
                cols.calc_price_per_unit: [np.nan, 10.0, 20.0],
                cols.calc_units_per_trans: [0, 25.0, 10.0],
                cols.customers_pct: [0.6, 0.4, 1.0],
            },
        )
        segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)

        pd.testing.assert_frame_equal(segment_stats, expected_output)

    def test_excludes_total_row_when_calc_total_false(self, base_df):
        """Test that the method excludes the total row when calc_total=False."""
        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "B"],
                cols.agg_unit_spend: [500.0, 500.0],
                cols.agg_transaction_id: [3, 2],
                cols.agg_customer_id: [3, 2],
                cols.agg_unit_qty: [50, 50],
                cols.calc_spend_per_cust: [166.666667, 250.0],
                cols.calc_spend_per_trans: [166.666667, 250.0],
                cols.calc_trans_per_cust: [1.0, 1.0],
                cols.calc_price_per_unit: [10.0, 10.0],
                cols.calc_units_per_trans: [16.666667, 25.0],
                cols.customers_pct: [1.0, 1.0],
            },
        )

        segment_stats = (
            SegTransactionStats(base_df, "segment_name", calc_total=False)
            .df.sort_values("segment_name")
            .reset_index(drop=True)
        )

        pd.testing.assert_frame_equal(segment_stats, expected_output)


class TestSegTransactionStats:
    """Tests for the SegTransactionStats class."""

    def test_handles_empty_dataframe_with_errors(self):
        """Test that the method raises an error when the DataFrame is missing a required column."""
        # The required customer_id column is deliberately omitted here.
        df = pd.DataFrame(
            columns=[cols.unit_spend, cols.transaction_id, cols.unit_qty],
        )

        with pytest.raises(ValueError):
            SegTransactionStats(df, "segment_name")

    def test_multiple_segment_columns(self):
        """Test that the class correctly handles multiple segment columns."""
        df = pd.DataFrame(
            {
                cols.customer_id: [1, 1, 2, 2, 3, 3],
                cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
                cols.transaction_id: [101, 102, 103, 104, 105, 106],
                "segment_name": ["A", "A", "B", "B", "A", "A"],
                "region": ["North", "North", "South", "South", "East", "East"],
            },
        )

        # Test with a list of segment columns
        seg_stats = SegTransactionStats(df, ["segment_name", "region"])

        # Create expected DataFrame with the combinations actually produced
        expected_output = pd.DataFrame(
            {
                "segment_name": ["A", "A", "B", "Total"],
                "region": ["East", "North", "South", "Total"],
                cols.agg_unit_spend: [650.0, 250.0, 450.0, 1350.0],
                cols.agg_transaction_id: [2, 2, 2, 6],
                cols.agg_customer_id: [1, 1, 1, 3],
                cols.calc_spend_per_cust: [650.0, 250.0, 450.0, 450.0],
                cols.calc_spend_per_trans: [325.0, 125.0, 225.0, 225.0],
                cols.calc_trans_per_cust: [2.0, 2.0, 2.0, 2.0],
                cols.customers_pct: [1 / 3, 1 / 3, 1 / 3, 1.0],
            },
        )

        # Sort both dataframes by the segment columns for consistent comparison
        result_df = seg_stats.df.sort_values(["segment_name", "region"]).reset_index(drop=True)
        expected_output = expected_output.sort_values(["segment_name", "region"]).reset_index(drop=True)

        # Check that both segment columns are in the result
        assert "segment_name" in result_df.columns
        assert "region" in result_df.columns

        # Check number of rows - the implementation only returns actual combinations that exist in data
        # plus the Total row, not all possible combinations
        assert len(result_df) == len(expected_output)

        # Use pandas testing to compare the dataframes
        pd.testing.assert_frame_equal(result_df[expected_output.columns], expected_output)

    def test_plot_with_multiple_segment_columns(self):
        """Test that plotting with multiple segment columns raises a ValueError."""
        df = pd.DataFrame(
            {
                cols.customer_id: [1, 2, 3],
                cols.unit_spend: [100.0, 200.0, 300.0],
                cols.transaction_id: [101, 102, 103],
                "segment_name": ["A", "B", "A"],
                "region": ["North", "South", "East"],
            },
        )

        seg_stats = SegTransactionStats(df, ["segment_name", "region"])

        with pytest.raises(ValueError) as excinfo:
            seg_stats.plot("spend")

        assert "Plotting is only supported for a single segment column" in str(excinfo.value)

    def test_extra_aggs_functionality(self):
        """Test that the extra_aggs parameter works correctly."""
        # Constants for expected values
        segment_a_store_count = 3  # Segment A has stores 1, 2, 4
        segment_b_store_count = 2  # Segment B has stores 1, 3
        total_store_count = 4  # Total has stores 1, 2, 3, 4

        segment_a_product_count = 3  # Segment A has products 10, 20, 40
        segment_b_product_count = 2  # Segment B has products 10, 30
        total_product_count = 4  # Total has products 10, 20, 30, 40
        df = pd.DataFrame(
            {
                cols.customer_id: [1, 1, 2, 2, 3, 3],
                cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
                cols.transaction_id: [101, 102, 103, 104, 105, 106],
                "segment_name": ["A", "A", "B", "B", "A", "A"],
                "store_id": [1, 2, 1, 3, 2, 4],
                "product_id": [10, 20, 10, 30, 20, 40],
            },
        )

        # Test with a single extra aggregation
        seg_stats = SegTransactionStats(
            df,
            "segment_name",
            extra_aggs={"distinct_stores": ("store_id", "nunique")},
        )

        # Verify the extra column exists and has correct values
        assert "distinct_stores" in seg_stats.df.columns

        # Sort by segment_name to ensure consistent order
        result_df = seg_stats.df.sort_values("segment_name").reset_index(drop=True)

        assert result_df.loc[0, "distinct_stores"] == segment_a_store_count  # Segment A
        assert result_df.loc[1, "distinct_stores"] == segment_b_store_count  # Segment B
        assert result_df.loc[2, "distinct_stores"] == total_store_count  # Total

        # Test with multiple extra aggregations
        seg_stats_multi = SegTransactionStats(
            df,
            "segment_name",
            extra_aggs={
                "distinct_stores": ("store_id", "nunique"),
                "distinct_products": ("product_id", "nunique"),
            },
        )

        # Verify both extra columns exist
        assert "distinct_stores" in seg_stats_multi.df.columns
        assert "distinct_products" in seg_stats_multi.df.columns

        # Sort by segment_name to ensure consistent order
        result_df_multi = seg_stats_multi.df.sort_values("segment_name").reset_index(drop=True)

        assert result_df_multi["distinct_products"].to_list() == [
            segment_a_product_count,
            segment_b_product_count,
            total_product_count,
        ]

    def test_extra_aggs_with_invalid_column(self):
        """Test that an error is raised when an invalid column is specified in extra_aggs."""
        df = pd.DataFrame(
            {
                cols.customer_id: [1, 2, 3],
                cols.unit_spend: [100.0, 200.0, 300.0],
                cols.transaction_id: [101, 102, 103],
                "segment_name": ["A", "B", "A"],
            },
        )

        with pytest.raises(ValueError) as excinfo:
            SegTransactionStats(df, "segment_name", extra_aggs={"invalid_agg": ("nonexistent_column", "nunique")})

        assert "does not exist in the data" in str(excinfo.value)

    def test_extra_aggs_with_invalid_function(self):
        """Test that an error is raised when an invalid function is specified in extra_aggs."""
        df = pd.DataFrame(
            {
                cols.customer_id: [1, 2, 3],
                cols.unit_spend: [100.0, 200.0, 300.0],
                cols.transaction_id: [101, 102, 103],
                "segment_name": ["A", "B", "A"],
            },
        )

        with pytest.raises(ValueError) as excinfo:
            SegTransactionStats(df, "segment_name", extra_aggs={"invalid_agg": (cols.customer_id, "invalid_function")})

        assert "not available for column" in str(excinfo.value)
diff --git a/tests/segmentation/test_threshold.py b/tests/segmentation/test_threshold.py
new file mode 100644
index 0000000..10981ee
--- /dev/null
+++ 
b/tests/segmentation/test_threshold.py @@ -0,0 +1,187 @@ +"""Tests for the ThresholdSegmentation class.""" + +import pandas as pd +import pytest + +from pyretailscience.options import ColumnHelper, get_option +from pyretailscience.segmentation.threshold import ThresholdSegmentation + +cols = ColumnHelper() + + +class TestThresholdSegmentation: + """Tests for the ThresholdSegmentation class.""" + + def test_correct_segmentation(self): + """Test that the method correctly segments customers based on given thresholds and segments.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4], + cols.unit_spend: [100, 200, 300, 400], + }, + ) + thresholds = [0.5, 1] + segments = ["Low", "High"] + seg = ThresholdSegmentation( + df=df, + thresholds=thresholds, + segments=segments, + value_col=cols.unit_spend, + zero_value_customers="exclude", + ) + result_df = seg.df + assert result_df.loc[1, "segment_name"] == "Low" + assert result_df.loc[2, "segment_name"] == "Low" + assert result_df.loc[3, "segment_name"] == "High" + assert result_df.loc[4, "segment_name"] == "High" + + def test_single_customer(self): + """Test that the method correctly segments a DataFrame with only one customer.""" + df = pd.DataFrame({get_option("column.customer_id"): [1], cols.unit_spend: [100]}) + thresholds = [0.5, 1] + segments = ["Low"] + with pytest.raises(ValueError): + ThresholdSegmentation( + df=df, + thresholds=thresholds, + segments=segments, + ) + + def test_correct_aggregation_function(self): + """Test that the correct aggregation function is applied for product_id custom segmentation.""" + df = pd.DataFrame( + { + cols.customer_id: [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5], + "product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1], + }, + ) + value_col = "product_id" + agg_func = "nunique" + + my_seg = ThresholdSegmentation( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=[0.2, 0.8, 1], + segments=["Low", "Medium", "High"], + 
zero_value_customers="separate_segment", + ) + + expected_result = pd.DataFrame( + { + cols.customer_id: [1, 2, 3, 4, 5], + "product_id": [1, 4, 2, 2, 3], + "segment_name": ["Low", "High", "Medium", "Medium", "Medium"], + }, + ) + pd.testing.assert_frame_equal(my_seg.df.sort_values(cols.customer_id).reset_index(), expected_result) + + def test_correctly_checks_segment_data(self): + """Test that the method correctly merges segment data back into the original DataFrame.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [100, 200, 0, 150, 0], + }, + ) + value_col = cols.unit_spend + agg_func = "sum" + thresholds = [0.33, 0.66, 1] + segments = ["Low", "Medium", "High"] + zero_value_customers = "separate_segment" + + # Create ThresholdSegmentation instance + threshold_seg = ThresholdSegmentation( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=thresholds, + segments=segments, + zero_value_customers=zero_value_customers, + ) + + # Call add_segment method + segmented_df = threshold_seg.add_segment(df) + + # Assert the correct segment_name + expected_df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [100, 200, 0, 150, 0], + "segment_name": ["Low", "High", "Zero", "Medium", "Zero"], + }, + ) + pd.testing.assert_frame_equal(segmented_df, expected_df) + + def test_handles_dataframe_with_duplicate_customer_id_entries(self): + """Test that the method correctly handles a DataFrame with duplicate customer_id entries.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 1, 2, 3], + cols.unit_spend: [100, 200, 300, 150, 250, 350], + }, + ) + + my_seg = ThresholdSegmentation( + df=df, + value_col=cols.unit_spend, + agg_func="sum", + thresholds=[0.5, 0.8, 1], + segments=["Light", "Medium", "Heavy"], + zero_value_customers="include_with_light", + ) + + result_df = my_seg.add_segment(df) + assert len(result_df) == len(df) + + def 
test_thresholds_not_unique(self): + """Test that the method raises an error when the thresholds are not unique.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [100, 200, 300, 400, 500], + }, + ) + thresholds = [0.5, 0.5, 0.8, 1] + segments = ["Low", "Medium", "High"] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + def test_thresholds_too_few_segments(self): + """Test that the method raises an error when there are too few/many segments for the number of thresholds.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [100, 200, 300, 400, 500], + }, + ) + thresholds = [0.4, 0.6, 0.8, 1] + segments = ["Low", "High"] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + segments = ["Low", "Medium", "High"] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + def test_thresholds_too_too_few_thresholds(self): + """Test that the method raises an error when there are too few/many thresholds for the number of segments.""" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + cols.unit_spend: [100, 200, 300, 400, 500], + }, + ) + thresholds = [0.4, 1] + segments = ["Low", "Medium", "High"] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + thresholds = [0.2, 0.5, 0.6, 0.8, 1] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments)