From 79a08293e27f85498e5e54410c676f0bad4fda8a Mon Sep 17 00:00:00 2001 From: Murray Vanwyk Date: Fri, 5 Jul 2024 19:33:43 +0200 Subject: [PATCH 1/2] feat: add input validation and tests in HMLSegmentation --- pyretailscience/segmentation.py | 10 ++- tests/test_segmentation.py | 109 +++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 2 deletions(-) diff --git a/pyretailscience/segmentation.py b/pyretailscience/segmentation.py index df2d67ad..00de7b48 100644 --- a/pyretailscience/segmentation.py +++ b/pyretailscience/segmentation.py @@ -88,6 +88,9 @@ def __init__( ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain null values. """ + if df.empty: + raise ValueError("Input DataFrame is empty") + required_cols = ["customer_id", value_col] contract = CustomContract( df, @@ -99,6 +102,11 @@ def __init__( msg = f"The dataframe requires the columns {required_cols} and they must be non-null" raise ValueError(msg) + hml_cuts = [0.500, 0.800, 1] + if len(df) < len(hml_cuts): + msg = f"There are {len(df)} customers, which is less than is less than the number of segment thresholds." + raise ValueError(msg) + # Group by customer_id and calculate total_spend grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col) @@ -114,7 +122,7 @@ def __init__( # Create a new column 'segment' based on the total_spend hml_df["segment_name"] = pd.qcut( hml_df[value_col], - q=[0, 0.500, 0.800, 1], + q=[0, *hml_cuts], labels=["Light", "Medium", "Heavy"], ) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 22b215a5..a299980c 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from pyretailscience.segmentation import SegTransactionStats +from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats class TestCalcSegStats: @@ -99,3 +99,110 @@ def test_handles_empty_dataframe_with_errors(self): with pytest.raises(ValueError): SegTransactionStats(df, "segment_id") + + +class TestHMLSegmentation: + """Tests for the HMLSegmentation class.""" + + @pytest.fixture() + def base_df(self): + """Return a base DataFrame for testing.""" + return pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [1000, 200, 0, 500, 300]}) + + def test_no_transactions(self): + """Test that the method raises an error when there are no transactions.""" + data = {"customer_id": [], "total_price": []} + df = pd.DataFrame(data) + with pytest.raises(ValueError): + HMLSegmentation(df) + + # Correctly handles zero spend customers when zero_value_customers is "exclude" + def test_handles_zero_spend_customers_are_excluded_in_result(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "exclude".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="exclude") + result_df = hml_segmentation.df + + zero_spend_customer_id = 3 + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[1, "segment_id"] == "H" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[2, "segment_id"] == "L" + assert zero_spend_customer_id not in result_df.index + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[4, "segment_id"] == "M" + assert result_df.loc[5, "segment_name"] == "Light" + assert result_df.loc[5, "segment_id"] == "L" + + # Correctly handles zero spend customers when zero_value_customers is "include_with_light" + def test_handles_zero_spend_customers_include_with_light(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "include_with_light".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="include_with_light") + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[1, "segment_id"] == "H" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[2, "segment_id"] == "L" + assert result_df.loc[3, "segment_name"] == "Light" + assert result_df.loc[3, "segment_id"] == "L" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[4, "segment_id"] == "M" + assert result_df.loc[5, "segment_name"] == "Light" + assert result_df.loc[5, "segment_id"] == "L" + + # Correctly handles zero spend customers when zero_value_customers is "separate_segment" + def test_handles_zero_spend_customers_separate_segment(self, base_df): + """Test that the method correctly handles zero spend customers when zero_value_customers is "separate_segment".""" + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="separate_segment") + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[1, "segment_id"] == "H" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[2, "segment_id"] == "L" + assert result_df.loc[3, "segment_name"] == "Zero" + assert result_df.loc[3, "segment_id"] == "Z" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[4, "segment_id"] == "M" + assert result_df.loc[5, "segment_name"] == "Light" + assert result_df.loc[5, "segment_id"] == "L" + + # Raises ValueError if required columns are missing + def test_raises_value_error_if_required_columns_missing(self, base_df): + """Test that the method raises an error when the DataFrame is missing a required column.""" + with pytest.raises(ValueError): + HMLSegmentation(base_df.drop(columns=["customer_id"])) + + # DataFrame with only one customer + def test_segments_customer_single(self): + """Test that the method correctly segments a DataFrame with only one customer.""" + data = {"customer_id": [1], "total_price": [0]} + df = pd.DataFrame(data) + with pytest.raises(ValueError): + HMLSegmentation(df) + + # Validate that the input dataframe is not changed + def test_input_dataframe_not_changed(self, base_df): + """Test that the method does not alter the original DataFrame.""" + original_df = base_df.copy() + + hml_segmentation = HMLSegmentation(base_df) + _ = hml_segmentation.df + + assert original_df.equals(base_df) # Check if the original dataframe is not changed + + def test_alternate_value_col(self, base_df): + """Test that the method correctly segments a DataFrame with an alternate value column.""" + base_df = base_df.rename(columns={"total_price": "quantity"}) + hml_segmentation = HMLSegmentation(base_df, value_col="quantity") + result_df = hml_segmentation.df + + assert result_df.loc[1, "segment_name"] == "Heavy" + assert result_df.loc[1, "segment_id"] == "H" + assert result_df.loc[2, "segment_name"] == "Light" + assert result_df.loc[2, "segment_id"] == "L" + assert result_df.loc[4, "segment_name"] == "Medium" + assert result_df.loc[4, "segment_id"] == "M" + assert result_df.loc[5, "segment_name"] == "Light" + assert result_df.loc[5, "segment_id"] == "L" From fa6d7d7d0845c433705dc8a3600fdae0abc56b68 Mon Sep 17 00:00:00 2001 From: Murray Vanwyk Date: Tue, 9 Jul 2024 16:32:15 +0200 Subject: [PATCH 2/2] feat: added treshold segmentation creation --- pyretailscience/segmentation.py | 91 +++++++++++---- tests/test_segmentation.py | 189 +++++++++++++++++++++++++++++++- 2 files changed, 259 insertions(+), 21 deletions(-) diff --git a/pyretailscience/segmentation.py b/pyretailscience/segmentation.py index 00de7b48..86093829 100644 --- a/pyretailscience/segmentation.py +++ b/pyretailscience/segmentation.py @@ -67,20 +67,30 @@ def __init__(self, df: pd.DataFrame) -> None: self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id") -class HMLSegmentation(BaseSegmentation): - """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.""" +class ThresholdSegmentation(BaseSegmentation): + """Segments customers based on user-defined thresholds and segments.""" def __init__( self, df: pd.DataFrame, + thresholds: list[float], + segments: dict[any, str], value_col: str = "total_price", + agg_func: str = "sum", + zero_segment_name: str = "Zero", + zero_segment_id: str = "Z", zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", ) -> None: - """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend. + """Segments customers based on user-defined thresholds and segments. Args: df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. - value_col (str, optional): The column to use for the segmentation. Defaults to "total_price". + thresholds (List[float]): The percentile thresholds for segmentation. + segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names. + value_col (str): The column to use for the segmentation. + agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". + zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero". + zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z". zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle customers with zero spend. Defaults to "separate_segment". @@ -102,38 +112,79 @@ def __init__( msg = f"The dataframe requires the columns {required_cols} and they must be non-null" raise ValueError(msg) - hml_cuts = [0.500, 0.800, 1] - if len(df) < len(hml_cuts): - msg = f"There are {len(df)} customers, which is less than is less than the number of segment thresholds." + if len(df) < len(thresholds): + msg = f"There are {len(df)} customers, which is less than the number of segment thresholds." raise ValueError(msg) + if set(thresholds) != set(thresholds): + raise ValueError("The thresholds must be unique.") + + thresholds = sorted(thresholds) + if thresholds[0] != 0: + thresholds = [0, *thresholds] + if thresholds[-1] != 1: + thresholds.append(1) + + if len(thresholds) - 1 != len(segments): + raise ValueError("The number of thresholds must match the number of segments.") + # Group by customer_id and calculate total_spend - grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col) + grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col) # Separate customers with zero spend - hml_df = grouped_df + self.df = grouped_df if zero_value_customers in ["separate_segment", "exclude"]: zero_idx = grouped_df[value_col] == 0 - zero_cust_df = grouped_df[zero_idx] - zero_cust_df["segment_name"] = "Zero" + zero_cust_df = grouped_df[zero_idx].copy() + zero_cust_df["segment_name"] = zero_segment_name + zero_cust_df["segment_id"] = zero_segment_id - hml_df = grouped_df[~zero_idx] + self.df = grouped_df[~zero_idx] # Create a new column 'segment' based on the total_spend - hml_df["segment_name"] = pd.qcut( - hml_df[value_col], - q=[0, *hml_cuts], - labels=["Light", "Medium", "Heavy"], + labels = list(segments.values()) + + self.df["segment_name"] = pd.qcut( + self.df[value_col], + q=thresholds, + labels=labels, ) + self.df["segment_id"] = self.df["segment_name"].map({v: k for k, v in segments.items()}) + if zero_value_customers == "separate_segment": - hml_df = pd.concat([hml_df, zero_cust_df]) + self.df = pd.concat([self.df, zero_cust_df]) + - segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"} +class HMLSegmentation(ThresholdSegmentation): + """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.""" - hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map) + def __init__( + self, + df: pd.DataFrame, + value_col: str = "total_price", + agg_func: str = "sum", + zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", + ) -> None: + """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend. - self.df = hml_df + Args: + df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. + value_col (str, optional): The column to use for the segmentation. Defaults to "total_price". + agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". + zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle + customers with zero spend. Defaults to "separate_segment". + """ + thresholds = [0.500, 0.800, 1] + segments = {"L": "Light", "M": "Medium", "H": "Heavy"} + super().__init__( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=thresholds, + segments=segments, + zero_value_customers=zero_value_customers, + ) class SegTransactionStats: diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index a299980c..cc923a9d 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats +from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats, ThresholdSegmentation class TestCalcSegStats: @@ -90,6 +90,193 @@ def test_handles_dataframe_with_one_segment(self, base_df): pd.testing.assert_frame_equal(segment_stats, expected_output) +class TestThresholdSegmentation: + """Tests for the ThresholdSegmentation class.""" + + def test_correct_segmentation(self): + """Test that the method correctly segments customers based on given thresholds and segments.""" + df = pd.DataFrame({"customer_id": [1, 2, 3, 4], "total_price": [100, 200, 300, 400]}) + thresholds = [0.5, 1] + segments = {0: "Low", 1: "High"} + seg = ThresholdSegmentation( + df=df, + thresholds=thresholds, + segments=segments, + value_col="total_price", + zero_value_customers="exclude", + ) + result_df = seg.df + assert result_df.loc[1, "segment_name"] == "Low" + assert result_df.loc[2, "segment_name"] == "Low" + assert result_df.loc[3, "segment_name"] == "High" + assert result_df.loc[4, "segment_name"] == "High" + + def test_single_customer(self): + """Test that the method correctly segments a DataFrame with only one customer.""" + df = pd.DataFrame({"customer_id": [1], "total_price": [100]}) + thresholds = [0.5, 1] + segments = {0: "Low"} + with pytest.raises(ValueError): + ThresholdSegmentation( + df=df, + thresholds=thresholds, + segments=segments, + ) + + def test_correct_aggregation_function(self): + """Test that the correct aggregation function is applied for product_id custom segmentation.""" + df = pd.DataFrame( + { + "customer_id": [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5], + "product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1], + }, + ) + value_col = "product_id" + agg_func = "nunique" + + my_seg = ThresholdSegmentation( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=[0.2, 0.8, 1], + segments={"A": "Low", "B": "Medium", "C": "High"}, + zero_value_customers="separate_segment", + ) + + expected_result = pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "product_id": [1, 4, 2, 2, 3], + "segment_name": ["Low", "High", "Medium", "Medium", "Medium"], + "segment_id": ["A", "C", "B", "B", "B"], + }, + ) + expected_result["segment_id"] = pd.Categorical( + expected_result["segment_id"], + categories=["A", "B", "C"], + ordered=True, + ) + expected_result["segment_name"] = pd.Categorical( + expected_result["segment_name"], + categories=["Low", "Medium", "High"], + ordered=True, + ) + pd.testing.assert_frame_equal(my_seg.df.reset_index(), expected_result) + + def test_correctly_checks_segment_data(self): + """Test that the method correctly merges segment data back into the original DataFrame.""" + df = pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "total_price": [100, 200, 0, 150, 0], + }, + ) + value_col = "total_price" + agg_func = "sum" + thresholds = [0.33, 0.66, 1] + segments = {"A": "Low", "B": "Medium", "C": "High"} + zero_value_customers = "separate_segment" + + # Create ThresholdSegmentation instance + threshold_seg = ThresholdSegmentation( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=thresholds, + segments=segments, + zero_value_customers=zero_value_customers, + ) + + # Call add_segment method + segmented_df = threshold_seg.add_segment(df) + + # Assert the correct segment_name and segment_id + expected_df = pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "total_price": [100, 200, 0, 150, 0], + "segment_name": ["Low", "High", "Zero", "Medium", "Zero"], + "segment_id": ["A", "C", "Z", "B", "Z"], + }, + ) + pd.testing.assert_frame_equal(segmented_df, expected_df) + + def test_handles_dataframe_with_duplicate_customer_id_entries(self): + """Test that the method correctly handles a DataFrame with duplicate customer_id entries.""" + df = pd.DataFrame({"customer_id": [1, 2, 3, 1, 2, 3], "total_price": [100, 200, 300, 150, 250, 350]}) + + my_seg = ThresholdSegmentation( + df=df, + value_col="total_price", + agg_func="sum", + thresholds=[0.5, 0.8, 1], + segments={"L": "Light", "M": "Medium", "H": "Heavy"}, + zero_value_customers="include_with_light", + ) + + result_df = my_seg.add_segment(df) + assert len(result_df) == len(df) + + def test_correctly_maps_segment_names_to_segment_ids_with_fixed_thresholds(self): + """Test that the method correctly maps segment names to segment IDs with fixed thresholds.""" + # Setup + df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + value_col = "total_price" + agg_func = "sum" + thresholds = [0.33, 0.66, 1] + segments = {1: "Low", 2: "Medium", 3: "High"} + zero_value_customers = "separate_segment" + + my_seg = ThresholdSegmentation( + df=df, + value_col=value_col, + agg_func=agg_func, + thresholds=thresholds, + segments=segments, + zero_value_customers=zero_value_customers, + ) + + assert len(my_seg.df[["segment_id", "segment_name"]].drop_duplicates()) == len(segments) + assert my_seg.df.set_index("segment_id")["segment_name"].to_dict() == segments + + def test_thresholds_not_unique(self): + """Test that the method raises an error when the thresholds are not unique.""" + df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + thresholds = [0.5, 0.5, 0.8, 1] + segments = {1: "Low", 2: "Medium", 3: "High"} + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + def test_thresholds_too_few_segments(self): + """Test that the method raises an error when there are too few/many segments for the number of thresholds.""" + df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + thresholds = [0.4, 0.6, 0.8, 1] + segments = {1: "Low", 3: "High"} + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + segments = {1: "Low", 2: "Medium", 3: "High"} + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + def test_thresholds_too_too_few_thresholds(self): + """Test that the method raises an error when there are too few/many thresholds for the number of segments.""" + df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + thresholds = [0.4, 1] + segments = {1: "Low", 2: "Medium", 3: "High"} + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + thresholds = [0.2, 0.5, 0.6, 0.8, 1] + + with pytest.raises(ValueError): + ThresholdSegmentation(df, thresholds, segments) + + class TestSegTransactionStats: """Tests for the SegTransactionStats class."""