diff --git a/pyretailscience/analysis/segmentation.py b/pyretailscience/analysis/segmentation.py index 22d631ef..8a4748ab 100644 --- a/pyretailscience/analysis/segmentation.py +++ b/pyretailscience/analysis/segmentation.py @@ -194,6 +194,7 @@ def __init__( self, data: pd.DataFrame | ibis.Table, segment_col: str | list[str] = "segment_name", + calc_total: bool = True, extra_aggs: dict[str, tuple[str, str]] | None = None, ) -> None: """Calculates transaction statistics by segment. @@ -205,6 +206,7 @@ def __init__( units_per_transaction. segment_col (str | list[str], optional): The column or list of columns to use for the segmentation. Defaults to "segment_name". + calc_total (bool, optional): Whether to include the total row. Defaults to True. extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform. The keys in the dictionary will be the column names for the aggregation results. The values are tuples with (column_name, aggregation_function), where: @@ -244,7 +246,7 @@ def __init__( self.segment_col = segment_col self.extra_aggs = {} if extra_aggs is None else extra_aggs - self.table = self._calc_seg_stats(data, segment_col, self.extra_aggs) + self.table = self._calc_seg_stats(data, segment_col, calc_total, self.extra_aggs) @staticmethod def _get_col_order(include_quantity: bool) -> list[str]: @@ -279,6 +281,7 @@ def _get_col_order(include_quantity: bool) -> list[str]: def _calc_seg_stats( data: pd.DataFrame | ibis.Table, segment_col: list[str], + calc_total: bool = True, extra_aggs: dict[str, tuple[str, str]] | None = None, ) -> ibis.Table: """Calculates the transaction statistics by segment. @@ -287,6 +290,7 @@ def _calc_seg_stats( data (pd.DataFrame | ibis.Table): The transaction data. segment_col (list[str]): The columns to use for the segmentation. + calc_total (bool, optional): Whether to include the total row. Defaults to True. extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
The keys in the dictionary will be the column names for the aggregation results. The values are tuples with (column_name, aggregation_function). @@ -298,7 +302,7 @@ def _calc_seg_stats( data = ibis.memtable(data) elif not isinstance(data, ibis.Table): - raise TypeError("data must be either a pandas DataFrame or a ibis Table") + raise TypeError("data must be either a pandas DataFrame or an ibis Table") cols = ColumnHelper() @@ -317,13 +321,18 @@ def _calc_seg_stats( col, func = col_tuple aggs[agg_name] = getattr(data[col], func)() - # Calculate metrics for segments and total + # Calculate metrics for segments segment_metrics = data.group_by(segment_col).aggregate(**aggs) - total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col}) + final_metrics = segment_metrics + + if calc_total: + total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col}) + final_metrics = ibis.union(segment_metrics, total_metrics) + total_customers = data[cols.customer_id].nunique() # Cross join with total_customers to make it available for percentage calculation - final_metrics = ibis.union(segment_metrics, total_metrics).mutate( + final_metrics = final_metrics.mutate( **{ cols.calc_spend_per_cust: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_customer_id], cols.calc_spend_per_trans: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_transaction_id], diff --git a/tests/analysis/test_segmentation.py b/tests/analysis/test_segmentation.py index 4fe2e03f..b1768ffb 100644 --- a/tests/analysis/test_segmentation.py +++ b/tests/analysis/test_segmentation.py @@ -130,6 +130,32 @@ def test_handles_dataframe_with_zero_net_units(self, base_df): pd.testing.assert_frame_equal(segment_stats, expected_output) + def test_excludes_total_row_when_calc_total_false(self, base_df): + """Test that the method excludes the total row when calc_total=False.""" + expected_output = pd.DataFrame( + { + "segment_name": ["A", "B"], + 
cols.agg_unit_spend: [500.0, 500.0], + cols.agg_transaction_id: [3, 2], + cols.agg_customer_id: [3, 2], + cols.agg_unit_qty: [50, 50], + cols.calc_spend_per_cust: [166.666667, 250.0], + cols.calc_spend_per_trans: [166.666667, 250.0], + cols.calc_trans_per_cust: [1.0, 1.0], + cols.calc_price_per_unit: [10.0, 10.0], + cols.calc_units_per_trans: [16.666667, 25.0], + cols.customers_pct: [1.0, 1.0], + }, + ) + + segment_stats = ( + SegTransactionStats(base_df, "segment_name", calc_total=False) + .df.sort_values("segment_name") + .reset_index(drop=True) + ) + + pd.testing.assert_frame_equal(segment_stats, expected_output) + class TestThresholdSegmentation: """Tests for the ThresholdSegmentation class.""" @@ -407,7 +433,11 @@ def test_extra_aggs_functionality(self): ) # Test with a single extra aggregation - seg_stats = SegTransactionStats(df, "segment_name", extra_aggs={"distinct_stores": ("store_id", "nunique")}) + seg_stats = SegTransactionStats( + df, + "segment_name", + extra_aggs={"distinct_stores": ("store_id", "nunique")}, + ) # Verify the extra column exists and has correct values assert "distinct_stores" in seg_stats.df.columns diff --git a/uv.lock b/uv.lock index 4acc0114..bde00ac9 100644 --- a/uv.lock +++ b/uv.lock @@ -1847,7 +1847,7 @@ wheels = [ [[package]] name = "pyretailscience" -version = "0.12.1" +version = "0.12.2" source = { editable = "." } dependencies = [ { name = "duckdb" },