Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions pyretailscience/analysis/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class SegTransactionStats:
def __init__(
self,
data: pd.DataFrame | ibis.Table,
segment_col: str = "segment_name",
segment_col: str | list[str] = "segment_name",
extra_aggs: dict[str, tuple[str, str]] | None = None,
) -> None:
"""Calculates transaction statistics by segment.
Expand All @@ -202,7 +202,8 @@ def __init__(
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
units_per_transaction.
segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
Defaults to "segment_name".
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
The keys in the dictionary will be the column names for the aggregation results.
The values are tuples with (column_name, aggregation_function), where:
Expand All @@ -211,11 +212,14 @@ def __init__(
Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
"""
cols = ColumnHelper()

if isinstance(segment_col, str):
segment_col = [segment_col]
required_cols = [
cols.customer_id,
cols.unit_spend,
cols.transaction_id,
segment_col,
*segment_col,
]
if cols.unit_qty in data.columns:
required_cols.append(cols.unit_qty)
Expand Down Expand Up @@ -273,14 +277,14 @@ def _get_col_order(include_quantity: bool) -> list[str]:
@staticmethod
def _calc_seg_stats(
data: pd.DataFrame | ibis.Table,
segment_col: str,
segment_col: list[str],
extra_aggs: dict[str, tuple[str, str]] | None = None,
) -> ibis.Table:
"""Calculates the transaction statistics by segment.

Args:
data (pd.DataFrame | ibis.Table): The transaction data.
segment_col (str): The column to use for the segmentation.
segment_col (list[str]): The columns to use for the segmentation.
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
The keys in the dictionary will be the column names for the aggregation results.
The values are tuples with (column_name, aggregation_function).
Expand Down Expand Up @@ -314,7 +318,7 @@ def _calc_seg_stats(

# Calculate metrics for segments and total
segment_metrics = data.group_by(segment_col).aggregate(**aggs)
total_metrics = data.aggregate(**aggs).mutate(segment_name=ibis.literal("Total"))
total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col})
total_customers = data[cols.customer_id].nunique()

# Cross join with total_customers to make it available for percentage calculation
Expand Down Expand Up @@ -343,7 +347,7 @@ def df(self) -> pd.DataFrame:
if self._df is None:
cols = ColumnHelper()
col_order = [
self.segment_col,
*self.segment_col,
*SegTransactionStats._get_col_order(include_quantity=cols.agg_unit_qty in self.table.columns),
]

Expand Down Expand Up @@ -392,18 +396,23 @@ def plot(
Raises:
ValueError: If the sort_order is not "ascending", "descending" or None.
ValueError: If the orientation is not "vertical" or "horizontal".
ValueError: If multiple segment columns are used, as plotting is only supported for a single segment column.
"""
if sort_order not in ["ascending", "descending", None]:
raise ValueError("sort_order must be either 'ascending' or 'descending' or None")
if orientation not in ["vertical", "horizontal"]:
raise ValueError("orientation must be either 'vertical' or 'horizontal'")
if len(self.segment_col) > 1:
raise ValueError("Plotting is only supported for a single segment column")

default_title = f"{value_col.title()} by Segment"
kind = "bar"
if orientation == "horizontal":
kind = "barh"

val_s = self.df.set_index(self.segment_col)[value_col]
# Use the first segment column for plotting
plot_segment_col = self.segment_col[0]
val_s = self.df.set_index(plot_segment_col)[value_col]
if hide_total:
val_s = val_s[val_s.index != "Total"]

Expand Down
68 changes: 66 additions & 2 deletions tests/analysis/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class TestCalcSegStats:
"""Tests for the _calc_seg_stats method."""

@pytest.fixture()
@pytest.fixture
def base_df(self):
"""Return a base DataFrame for testing."""
return pd.DataFrame(
Expand Down Expand Up @@ -314,6 +314,70 @@ def test_handles_empty_dataframe_with_errors(self):
with pytest.raises(ValueError):
SegTransactionStats(df, "segment_name")

def test_multiple_segment_columns(self):
    """Verify that segmentation statistics are computed correctly for a list of segment columns."""
    transactions = pd.DataFrame(
        {
            cols.customer_id: [1, 1, 2, 2, 3, 3],
            cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
            cols.transaction_id: [101, 102, 103, 104, 105, 106],
            "segment_name": ["A", "A", "B", "B", "A", "A"],
            "region": ["North", "North", "South", "South", "East", "East"],
        },
    )

    # Segment by two columns at once.
    stats = SegTransactionStats(transactions, ["segment_name", "region"])

    # Only the (segment, region) combinations present in the data appear,
    # plus a single Total row.
    expected = pd.DataFrame(
        {
            "segment_name": ["A", "A", "B", "Total"],
            "region": ["East", "North", "South", "Total"],
            cols.agg_unit_spend: [650.0, 250.0, 450.0, 1350.0],
            cols.agg_transaction_id: [2, 2, 2, 6],
            cols.agg_customer_id: [1, 1, 1, 3],
            cols.calc_spend_per_cust: [650.0, 250.0, 450.0, 450.0],
            cols.calc_spend_per_trans: [325.0, 125.0, 225.0, 225.0],
            cols.calc_trans_per_cust: [2.0, 2.0, 2.0, 2.0],
            cols.customers_pct: [1 / 3, 1 / 3, 1 / 3, 1.0],
        },
    )

    # Sort both frames identically so row order cannot affect the comparison.
    sort_keys = ["segment_name", "region"]
    actual = stats.df.sort_values(sort_keys).reset_index(drop=True)
    expected = expected.sort_values(sort_keys).reset_index(drop=True)

    # Both segmentation columns must survive into the output frame.
    assert "segment_name" in actual.columns
    assert "region" in actual.columns

    # Row count: observed combinations plus the Total row only — not the
    # full cross product of segment values.
    assert len(actual) == len(expected)

    pd.testing.assert_frame_equal(actual[expected.columns], expected)

def test_plot_with_multiple_segment_columns(self):
    """Verify plot() raises ValueError when more than one segment column was used."""
    transactions = pd.DataFrame(
        {
            cols.customer_id: [1, 2, 3],
            cols.unit_spend: [100.0, 200.0, 300.0],
            cols.transaction_id: [101, 102, 103],
            "segment_name": ["A", "B", "A"],
            "region": ["North", "South", "East"],
        },
    )

    stats = SegTransactionStats(transactions, ["segment_name", "region"])

    # Plotting is restricted to a single segment column, so this must raise.
    with pytest.raises(ValueError, match="Plotting is only supported for a single segment column"):
        stats.plot("spend")

def test_extra_aggs_functionality(self):
"""Test that the extra_aggs parameter works correctly."""
# Constants for expected values
Expand Down Expand Up @@ -405,7 +469,7 @@ def test_extra_aggs_with_invalid_function(self):
class TestHMLSegmentation:
"""Tests for the HMLSegmentation class."""

@pytest.fixture()
@pytest.fixture
def base_df(self):
"""Return a base DataFrame for testing."""
return pd.DataFrame(
Expand Down
Loading