diff --git a/pyretailscience/segmentation.py b/pyretailscience/segmentation.py index d3d8ed39..df2d67ad 100644 --- a/pyretailscience/segmentation.py +++ b/pyretailscience/segmentation.py @@ -8,8 +8,6 @@ import pyretailscience.style.graph_utils as gu from pyretailscience.data.contracts import ( CustomContract, - TransactionItemLevelContract, - TransactionLevelContract, build_expected_columns, build_expected_unique_columns, build_non_null_columns, @@ -146,34 +144,51 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None: TransactionLevelContract. """ + required_cols = ["customer_id", "total_price", "transaction_id", segment_col] + if "quantity" in df.columns: + required_cols.append("quantity") + contract = CustomContract( + df, + basic_expectations=build_expected_columns(columns=required_cols), + extended_expectations=build_non_null_columns(columns=required_cols), + ) + + if contract.validate() is False: + msg = f"The dataframe requires the columns {required_cols} and they must be non-null" + raise ValueError(msg) + self.segment_col = segment_col - if TransactionItemLevelContract(df).validate() is True: - stats_df = df.groupby(segment_col).agg( - revenue=("total_price", "sum"), - transactions=("transaction_id", "nunique"), - customers=("customer_id", "nunique"), - total_quantity=("quantity", "sum"), - ) + + self.df = self._calc_seg_stats(df, segment_col) + + @staticmethod + def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame: + aggs = { + "revenue": ("total_price", "sum"), + "transactions": ("transaction_id", "nunique"), + "customers": ("customer_id", "nunique"), + } + total_aggs = { + "revenue": [df["total_price"].sum()], + "transactions": [df["transaction_id"].nunique()], + "customers": [df["customer_id"].nunique()], + } + if "quantity" in df.columns: + aggs["total_quantity"] = ("quantity", "sum") + total_aggs["total_quantity"] = [df["quantity"].sum()] + + stats_df = pd.concat( + [ + df.groupby(segment_col).agg(**aggs), + pd.DataFrame(total_aggs, index=["total"]), + ], + ) + + if "quantity" in df.columns: stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"] stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"] - elif TransactionLevelContract(df).validate() is True: - stats_df = df.groupby(segment_col).agg( - revenue=("total_price", "sum"), - transactions=("transaction_id", "nunique"), - customers=("customer_id", "nunique"), - ) - else: - raise NotImplementedError( - "The dataframe does not comply with the TransactionItemLevelContract or TransactionLevelContract. " - "These are the only two contracts supported at this time.", - ) - total_num_customers = df["customer_id"].nunique() - stats_df["spend_per_cust"] = stats_df["revenue"] / stats_df["customers"] - stats_df["spend_per_transaction"] = stats_df["revenue"] / stats_df["transactions"] - stats_df["transactions_per_customer"] = stats_df["transactions"] / stats_df["customers"] - stats_df["customers_pct"] = stats_df["customers"] / total_num_customers - self.df = stats_df + return stats_df def plot( self, @@ -185,6 +200,7 @@ def plot( orientation: Literal["vertical", "horizontal"] = "vertical", sort_order: Literal["ascending", "descending", None] = None, source_text: str | None = None, + hide_total: bool = True, **kwargs: dict[str, any], ) -> SubplotBase: """Plots the value_col by segment. @@ -203,6 +219,7 @@ def plot( sort_order (Literal["ascending", "descending", None], optional): The sort order of the segments. Defaults to None. If None, the segments are plotted in the order they appear in the dataframe. source_text (str, optional): The source text to add to the plot. Defaults to None. + hide_total (bool, optional): Whether to hide the total row. Defaults to True. **kwargs: Additional keyword arguments to pass to the Pandas plot function. Returns: @@ -223,6 +240,9 @@ def plot( kind = "barh" val_s = self.df[value_col] + if hide_total: + val_s = val_s[val_s.index != "total"] + if sort_order is not None: ascending = sort_order == "ascending" val_s = val_s.sort_values(ascending=ascending) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py new file mode 100644 index 00000000..22b215a5 --- /dev/null +++ b/tests/test_segmentation.py @@ -0,0 +1,101 @@ +"""Tests for the SegTransactionStats class.""" + +import pandas as pd +import pytest + +from pyretailscience.segmentation import SegTransactionStats + + +class TestCalcSegStats: + """Tests for the _calc_seg_stats method.""" + + @pytest.fixture() + def base_df(self): + """Return a base DataFrame for testing.""" + return pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "total_price": [100, 200, 150, 300, 250], + "transaction_id": [101, 102, 103, 104, 105], + "segment_id": ["A", "B", "A", "B", "A"], + "quantity": [10, 20, 15, 30, 25], + }, + ) + + def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df): + """Test that the method correctly calculates at the transaction-item level.""" + expected_output = pd.DataFrame( + { + "revenue": [500, 500, 1000], + "transactions": [3, 2, 5], + "customers": [3, 2, 5], + "total_quantity": [50, 50, 100], + "price_per_unit": [10.0, 10.0, 10.0], + "quantity_per_transaction": [16.666667, 25.0, 20.0], + }, + index=["A", "B", "total"], + ) + + segment_stats = SegTransactionStats._calc_seg_stats(base_df, "segment_id") + pd.testing.assert_frame_equal(segment_stats, expected_output) + + def test_correctly_calculates_revenue_transactions_customers(self): + """Test that the method correctly calculates at the transaction level.""" + df = pd.DataFrame( + { + "customer_id": [1, 2, 3, 4, 5], + "total_price": [100, 200, 150, 300, 250], + "transaction_id": [101, 102, 103, 104, 105], + "segment_id": ["A", "B", "A", "B", "A"], + }, + ) + + expected_output = pd.DataFrame( + { + "revenue": [500, 500, 1000], + "transactions": [3, 2, 5], + "customers": [3, 2, 5], + }, + index=["A", "B", "total"], + ) + + segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id") + pd.testing.assert_frame_equal(segment_stats, expected_output) + + def test_does_not_alter_original_dataframe(self, base_df): + """Test that the method does not alter the original DataFrame.""" + original_df = base_df.copy() + _ = SegTransactionStats._calc_seg_stats(base_df, "segment_id") + + pd.testing.assert_frame_equal(base_df, original_df) + + def test_handles_dataframe_with_one_segment(self, base_df): + """Test that the method correctly handles a DataFrame with only one segment.""" + df = base_df.copy() + df["segment_id"] = "A" + + expected_output = pd.DataFrame( + { + "revenue": [1000, 1000], + "transactions": [5, 5], + "customers": [5, 5], + "total_quantity": [100, 100], + "price_per_unit": [10.0, 10.0], + "quantity_per_transaction": [20.0, 20.0], + }, + index=["A", "total"], + ) + + segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id") + pd.testing.assert_frame_equal(segment_stats, expected_output) + + +class TestSegTransactionStats: + """Tests for the SegTransactionStats class.""" + + def test_handles_empty_dataframe_with_errors(self): + """Test that the method raises an error when the DataFrame is missing a required column.""" + df = pd.DataFrame(columns=["total_price", "transaction_id", "segment_id", "quantity"]) + + with pytest.raises(ValueError): + SegTransactionStats(df, "segment_id")