Skip to content

Commit d8c9965

Browse files
feat: add total row to SegTransactionStats calculation (#57)
* feat: add total row to SegTransactionStats calculation
* Update pyretailscience/segmentation.py

Co-authored-by: codiumai-pr-agent-pro[bot] <151058649+codiumai-pr-agent-pro[bot]@users.noreply.github.com>

---------

Co-authored-by: codiumai-pr-agent-pro[bot] <151058649+codiumai-pr-agent-pro[bot]@users.noreply.github.com>
1 parent 9a0d672 commit d8c9965

File tree

2 files changed

+147
-26
lines changed

2 files changed

+147
-26
lines changed

pyretailscience/segmentation.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import pyretailscience.style.graph_utils as gu
99
from pyretailscience.data.contracts import (
1010
CustomContract,
11-
TransactionItemLevelContract,
12-
TransactionLevelContract,
1311
build_expected_columns,
1412
build_expected_unique_columns,
1513
build_non_null_columns,
@@ -146,34 +144,51 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
146144
TransactionLevelContract.
147145
148146
"""
147+
required_cols = ["customer_id", "total_price", "transaction_id", segment_col]
148+
if "quantity" in df.columns:
149+
required_cols.append("quantity")
150+
contract = CustomContract(
151+
df,
152+
basic_expectations=build_expected_columns(columns=required_cols),
153+
extended_expectations=build_non_null_columns(columns=required_cols),
154+
)
155+
156+
if contract.validate() is False:
157+
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
158+
raise ValueError(msg)
159+
149160
self.segment_col = segment_col
150-
if TransactionItemLevelContract(df).validate() is True:
151-
stats_df = df.groupby(segment_col).agg(
152-
revenue=("total_price", "sum"),
153-
transactions=("transaction_id", "nunique"),
154-
customers=("customer_id", "nunique"),
155-
total_quantity=("quantity", "sum"),
156-
)
161+
162+
self.df = self._calc_seg_stats(df, segment_col)
163+
164+
@staticmethod
165+
def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
166+
aggs = {
167+
"revenue": ("total_price", "sum"),
168+
"transactions": ("transaction_id", "nunique"),
169+
"customers": ("customer_id", "nunique"),
170+
}
171+
total_aggs = {
172+
"revenue": [df["total_price"].sum()],
173+
"transactions": [df["transaction_id"].nunique()],
174+
"customers": [df["customer_id"].nunique()],
175+
}
176+
if "quantity" in df.columns:
177+
aggs["total_quantity"] = ("quantity", "sum")
178+
total_aggs["total_quantity"] = [df["quantity"].sum()]
179+
180+
stats_df = pd.concat(
181+
[
182+
df.groupby(segment_col).agg(**aggs),
183+
pd.DataFrame(total_aggs, index=["total"]),
184+
],
185+
)
186+
187+
if "quantity" in df.columns:
157188
stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
158189
stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
159-
elif TransactionLevelContract(df).validate() is True:
160-
stats_df = df.groupby(segment_col).agg(
161-
revenue=("total_price", "sum"),
162-
transactions=("transaction_id", "nunique"),
163-
customers=("customer_id", "nunique"),
164-
)
165-
else:
166-
raise NotImplementedError(
167-
"The dataframe does not comply with the TransactionItemLevelContract or TransactionLevelContract. "
168-
"These are the only two contracts supported at this time.",
169-
)
170-
total_num_customers = df["customer_id"].nunique()
171-
stats_df["spend_per_cust"] = stats_df["revenue"] / stats_df["customers"]
172-
stats_df["spend_per_transaction"] = stats_df["revenue"] / stats_df["transactions"]
173-
stats_df["transactions_per_customer"] = stats_df["transactions"] / stats_df["customers"]
174-
stats_df["customers_pct"] = stats_df["customers"] / total_num_customers
175190

176-
self.df = stats_df
191+
return stats_df
177192

178193
def plot(
179194
self,
@@ -185,6 +200,7 @@ def plot(
185200
orientation: Literal["vertical", "horizontal"] = "vertical",
186201
sort_order: Literal["ascending", "descending", None] = None,
187202
source_text: str | None = None,
203+
hide_total: bool = True,
188204
**kwargs: dict[str, any],
189205
) -> SubplotBase:
190206
"""Plots the value_col by segment.
@@ -203,6 +219,7 @@ def plot(
203219
sort_order (Literal["ascending", "descending", None], optional): The sort order of the segments.
204220
Defaults to None. If None, the segments are plotted in the order they appear in the dataframe.
205221
source_text (str, optional): The source text to add to the plot. Defaults to None.
222+
hide_total (bool, optional): Whether to hide the total row. Defaults to True.
206223
**kwargs: Additional keyword arguments to pass to the Pandas plot function.
207224
208225
Returns:
@@ -223,6 +240,9 @@ def plot(
223240
kind = "barh"
224241

225242
val_s = self.df[value_col]
243+
if hide_total:
244+
val_s = val_s[val_s.index != "total"]
245+
226246
if sort_order is not None:
227247
ascending = sort_order == "ascending"
228248
val_s = val_s.sort_values(ascending=ascending)

tests/test_segmentation.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for the SegTransactionStats class."""
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from pyretailscience.segmentation import SegTransactionStats
7+
8+
9+
class TestCalcSegStats:
    """Checks for the SegTransactionStats._calc_seg_stats static method."""

    @pytest.fixture()
    def base_df(self):
        """Provide five item-level rows spread over segments A and B."""
        records = {
            "customer_id": [1, 2, 3, 4, 5],
            "total_price": [100, 200, 150, 300, 250],
            "transaction_id": [101, 102, 103, 104, 105],
            "segment_id": ["A", "B", "A", "B", "A"],
            "quantity": [10, 20, 15, 30, 25],
        }
        return pd.DataFrame(records)

    def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df):
        """With a quantity column, per-segment and total rows include the quantity ratios."""
        actual = SegTransactionStats._calc_seg_stats(base_df, "segment_id")

        # Rows are segment A, segment B, then the appended "total" row.
        expected = pd.DataFrame(
            {
                "revenue": [500, 500, 1000],
                "transactions": [3, 2, 5],
                "customers": [3, 2, 5],
                "total_quantity": [50, 50, 100],
                "price_per_unit": [10.0, 10.0, 10.0],
                "quantity_per_transaction": [16.666667, 25.0, 20.0],
            },
            index=["A", "B", "total"],
        )
        pd.testing.assert_frame_equal(actual, expected)

    def test_correctly_calculates_revenue_transactions_customers(self):
        """Without a quantity column, only revenue, transactions and customers are produced."""
        transactions = pd.DataFrame(
            {
                "customer_id": [1, 2, 3, 4, 5],
                "total_price": [100, 200, 150, 300, 250],
                "transaction_id": [101, 102, 103, 104, 105],
                "segment_id": ["A", "B", "A", "B", "A"],
            },
        )

        actual = SegTransactionStats._calc_seg_stats(transactions, "segment_id")

        expected = pd.DataFrame(
            {
                "revenue": [500, 500, 1000],
                "transactions": [3, 2, 5],
                "customers": [3, 2, 5],
            },
            index=["A", "B", "total"],
        )
        pd.testing.assert_frame_equal(actual, expected)

    def test_does_not_alter_original_dataframe(self, base_df):
        """The input frame must compare equal to a copy taken before the call."""
        snapshot = base_df.copy()

        SegTransactionStats._calc_seg_stats(base_df, "segment_id")

        pd.testing.assert_frame_equal(base_df, snapshot)

    def test_handles_dataframe_with_one_segment(self, base_df):
        """With a single segment, the segment row and the total row carry the same figures."""
        single_segment = base_df.assign(segment_id="A")

        actual = SegTransactionStats._calc_seg_stats(single_segment, "segment_id")

        expected = pd.DataFrame(
            {
                "revenue": [1000, 1000],
                "transactions": [5, 5],
                "customers": [5, 5],
                "total_quantity": [100, 100],
                "price_per_unit": [10.0, 10.0],
                "quantity_per_transaction": [20.0, 20.0],
            },
            index=["A", "total"],
        )
        pd.testing.assert_frame_equal(actual, expected)
91+
92+
93+
class TestSegTransactionStats:
94+
"""Tests for the SegTransactionStats class."""
95+
96+
def test_handles_empty_dataframe_with_errors(self):
97+
"""Test that the method raises an error when the DataFrame is missing a required column."""
98+
df = pd.DataFrame(columns=["total_price", "transaction_id", "segment_id", "quantity"])
99+
100+
with pytest.raises(ValueError):
101+
SegTransactionStats(df, "segment_id")

0 commit comments

Comments
 (0)