Skip to content

Commit 030e7ed

Browse files
committed
feat: add total row to SegTransactionStats calculation
1 parent 9a0d672 commit 030e7ed

File tree

2 files changed

+145
-26
lines changed

2 files changed

+145
-26
lines changed

pyretailscience/segmentation.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import pyretailscience.style.graph_utils as gu
99
from pyretailscience.data.contracts import (
1010
CustomContract,
11-
TransactionItemLevelContract,
12-
TransactionLevelContract,
1311
build_expected_columns,
1412
build_expected_unique_columns,
1513
build_non_null_columns,
@@ -146,34 +144,49 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
146144
TransactionLevelContract.
147145
148146
"""
147+
required_cols = ["customer_id", "total_price", "transaction_id", segment_col]
148+
contract = CustomContract(
149+
df,
150+
basic_expectations=build_expected_columns(columns=required_cols),
151+
extended_expectations=build_non_null_columns(columns=required_cols),
152+
)
153+
154+
if contract.validate() is False:
155+
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
156+
raise ValueError(msg)
157+
149158
self.segment_col = segment_col
150-
if TransactionItemLevelContract(df).validate() is True:
151-
stats_df = df.groupby(segment_col).agg(
152-
revenue=("total_price", "sum"),
153-
transactions=("transaction_id", "nunique"),
154-
customers=("customer_id", "nunique"),
155-
total_quantity=("quantity", "sum"),
156-
)
159+
160+
self.df = self._calc_seg_stats(df, segment_col)
161+
162+
@staticmethod
163+
def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
164+
aggs = {
165+
"revenue": ("total_price", "sum"),
166+
"transactions": ("transaction_id", "nunique"),
167+
"customers": ("customer_id", "nunique"),
168+
}
169+
total_aggs = {
170+
"revenue": [df["total_price"].sum()],
171+
"transactions": [df["transaction_id"].nunique()],
172+
"customers": [df["customer_id"].nunique()],
173+
}
174+
if "quantity" in df.columns:
175+
aggs["total_quantity"] = ("quantity", "sum")
176+
total_aggs["total_quantity"] = [df["quantity"].sum()]
177+
178+
stats_df = pd.concat(
179+
[
180+
df.groupby(segment_col).agg(**aggs),
181+
pd.DataFrame(total_aggs, index=["total"]),
182+
],
183+
)
184+
185+
if "quantity" in df.columns:
157186
stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
158187
stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
159-
elif TransactionLevelContract(df).validate() is True:
160-
stats_df = df.groupby(segment_col).agg(
161-
revenue=("total_price", "sum"),
162-
transactions=("transaction_id", "nunique"),
163-
customers=("customer_id", "nunique"),
164-
)
165-
else:
166-
raise NotImplementedError(
167-
"The dataframe does not comply with the TransactionItemLevelContract or TransactionLevelContract. "
168-
"These are the only two contracts supported at this time.",
169-
)
170-
total_num_customers = df["customer_id"].nunique()
171-
stats_df["spend_per_cust"] = stats_df["revenue"] / stats_df["customers"]
172-
stats_df["spend_per_transaction"] = stats_df["revenue"] / stats_df["transactions"]
173-
stats_df["transactions_per_customer"] = stats_df["transactions"] / stats_df["customers"]
174-
stats_df["customers_pct"] = stats_df["customers"] / total_num_customers
175188

176-
self.df = stats_df
189+
return stats_df
177190

178191
def plot(
179192
self,
@@ -185,6 +198,7 @@ def plot(
185198
orientation: Literal["vertical", "horizontal"] = "vertical",
186199
sort_order: Literal["ascending", "descending", None] = None,
187200
source_text: str | None = None,
201+
hide_total: bool = True,
188202
**kwargs: dict[str, any],
189203
) -> SubplotBase:
190204
"""Plots the value_col by segment.
@@ -203,6 +217,7 @@ def plot(
203217
sort_order (Literal["ascending", "descending", None], optional): The sort order of the segments.
204218
Defaults to None. If None, the segments are plotted in the order they appear in the dataframe.
205219
source_text (str, optional): The source text to add to the plot. Defaults to None.
220+
hide_total (bool, optional): Whether to hide the total row. Defaults to True.
206221
**kwargs: Additional keyword arguments to pass to the Pandas plot function.
207222
208223
Returns:
@@ -223,6 +238,9 @@ def plot(
223238
kind = "barh"
224239

225240
val_s = self.df[value_col]
241+
if hide_total:
242+
val_s = val_s[val_s.index != "total"]
243+
226244
if sort_order is not None:
227245
ascending = sort_order == "ascending"
228246
val_s = val_s.sort_values(ascending=ascending)

tests/test_segmentation.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for the SegTransactionStats class."""
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from pyretailscience.segmentation import SegTransactionStats
7+
8+
9+
class TestCalcSegStats:
10+
"""Tests for the _calc_seg_stats method."""
11+
12+
@pytest.fixture()
13+
def base_df(self):
14+
"""Return a base DataFrame for testing."""
15+
return pd.DataFrame(
16+
{
17+
"customer_id": [1, 2, 3, 4, 5],
18+
"total_price": [100, 200, 150, 300, 250],
19+
"transaction_id": [101, 102, 103, 104, 105],
20+
"segment_id": ["A", "B", "A", "B", "A"],
21+
"quantity": [10, 20, 15, 30, 25],
22+
},
23+
)
24+
25+
def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df):
26+
"""Test that the method correctly calculates at the transaction-item level."""
27+
expected_output = pd.DataFrame(
28+
{
29+
"revenue": [500, 500, 1000],
30+
"transactions": [3, 2, 5],
31+
"customers": [3, 2, 5],
32+
"total_quantity": [50, 50, 100],
33+
"price_per_unit": [10.0, 10.0, 10.0],
34+
"quantity_per_transaction": [16.666667, 25.0, 20.0],
35+
},
36+
index=["A", "B", "total"],
37+
)
38+
39+
segment_stats = SegTransactionStats._calc_seg_stats(base_df, "segment_id")
40+
pd.testing.assert_frame_equal(segment_stats, expected_output)
41+
42+
def test_correctly_calculates_revenue_transactions_customers(self):
43+
"""Test that the method correctly calculates at the transaction level."""
44+
df = pd.DataFrame(
45+
{
46+
"customer_id": [1, 2, 3, 4, 5],
47+
"total_price": [100, 200, 150, 300, 250],
48+
"transaction_id": [101, 102, 103, 104, 105],
49+
"segment_id": ["A", "B", "A", "B", "A"],
50+
},
51+
)
52+
53+
expected_output = pd.DataFrame(
54+
{
55+
"revenue": [500, 500, 1000],
56+
"transactions": [3, 2, 5],
57+
"customers": [3, 2, 5],
58+
},
59+
index=["A", "B", "total"],
60+
)
61+
62+
segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
63+
pd.testing.assert_frame_equal(segment_stats, expected_output)
64+
65+
def test_does_not_alter_original_dataframe(self, base_df):
66+
"""Test that the method does not alter the original DataFrame."""
67+
original_df = base_df.copy()
68+
_ = SegTransactionStats._calc_seg_stats(base_df, "segment_id")
69+
70+
pd.testing.assert_frame_equal(base_df, original_df)
71+
72+
def test_handles_dataframe_with_one_segment(self, base_df):
73+
"""Test that the method correctly handles a DataFrame with only one segment."""
74+
df = base_df.copy()
75+
df["segment_id"] = "A"
76+
77+
expected_output = pd.DataFrame(
78+
{
79+
"revenue": [1000, 1000],
80+
"transactions": [5, 5],
81+
"customers": [5, 5],
82+
"total_quantity": [100, 100],
83+
"price_per_unit": [10.0, 10.0],
84+
"quantity_per_transaction": [20.0, 20.0],
85+
},
86+
index=["A", "total"],
87+
)
88+
89+
segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
90+
pd.testing.assert_frame_equal(segment_stats, expected_output)
91+
92+
93+
class TestSegTransactionStats:
94+
"""Tests for the SegTransactionStats class."""
95+
96+
def test_handles_empty_dataframe_with_errors(self):
97+
"""Test that the method raises an error when the DataFrame is missing a required column."""
98+
df = pd.DataFrame(columns=["total_price", "transaction_id", "segment_id", "quantity"])
99+
100+
with pytest.raises(ValueError):
101+
SegTransactionStats(df, "segment_id")

0 commit comments

Comments
 (0)