Skip to content

Commit 91edbf4

Browse files
authored
Merge pull request #153 from Data-Simply/feature/seg-transaction-stats
Seg-Transaction-Stats
2 parents a12b64b + af3dca1 commit 91edbf4

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

pyretailscience/analysis/segmentation.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
self,
195195
data: pd.DataFrame | ibis.Table,
196196
segment_col: str | list[str] = "segment_name",
197+
calc_total: bool = True,
197198
extra_aggs: dict[str, tuple[str, str]] | None = None,
198199
) -> None:
199200
"""Calculates transaction statistics by segment.
@@ -205,6 +206,7 @@ def __init__(
205206
units_per_transaction.
206207
segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207208
Defaults to "segment_name".
209+
calc_total (bool, optional): Whether to include the total row. Defaults to True.
208210
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
209211
The keys in the dictionary will be the column names for the aggregation results.
210212
The values are tuples with (column_name, aggregation_function), where:
@@ -244,7 +246,7 @@ def __init__(
244246
self.segment_col = segment_col
245247
self.extra_aggs = {} if extra_aggs is None else extra_aggs
246248

247-
self.table = self._calc_seg_stats(data, segment_col, self.extra_aggs)
249+
self.table = self._calc_seg_stats(data, segment_col, calc_total, self.extra_aggs)
248250

249251
@staticmethod
250252
def _get_col_order(include_quantity: bool) -> list[str]:
@@ -279,6 +281,7 @@ def _get_col_order(include_quantity: bool) -> list[str]:
279281
def _calc_seg_stats(
280282
data: pd.DataFrame | ibis.Table,
281283
segment_col: list[str],
284+
calc_total: bool = True,
282285
extra_aggs: dict[str, tuple[str, str]] | None = None,
283286
) -> ibis.Table:
284287
"""Calculates the transaction statistics by segment.
@@ -287,6 +290,7 @@ def _calc_seg_stats(
287290
data (pd.DataFrame | ibis.Table): The transaction data.
288291
segment_col (list[str]): The columns to use for the segmentation.
289292
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
293+
calc_total (bool, optional): Whether to include the total row. Defaults to True.
290294
The keys in the dictionary will be the column names for the aggregation results.
291295
The values are tuples with (column_name, aggregation_function).
292296
@@ -298,7 +302,7 @@ def _calc_seg_stats(
298302
data = ibis.memtable(data)
299303

300304
elif not isinstance(data, ibis.Table):
301-
raise TypeError("data must be either a pandas DataFrame or a ibis Table")
305+
raise TypeError("data must be either a pandas DataFrame or an ibis Table")
302306

303307
cols = ColumnHelper()
304308

@@ -317,13 +321,18 @@ def _calc_seg_stats(
317321
col, func = col_tuple
318322
aggs[agg_name] = getattr(data[col], func)()
319323

320-
# Calculate metrics for segments and total
324+
# Calculate metrics for segments
321325
segment_metrics = data.group_by(segment_col).aggregate(**aggs)
322-
total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col})
326+
final_metrics = segment_metrics
327+
328+
if calc_total:
329+
total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col})
330+
final_metrics = ibis.union(segment_metrics, total_metrics)
331+
323332
total_customers = data[cols.customer_id].nunique()
324333

325334
# Cross join with total_customers to make it available for percentage calculation
326-
final_metrics = ibis.union(segment_metrics, total_metrics).mutate(
335+
final_metrics = final_metrics.mutate(
327336
**{
328337
cols.calc_spend_per_cust: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_customer_id],
329338
cols.calc_spend_per_trans: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_transaction_id],

tests/analysis/test_segmentation.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,32 @@ def test_handles_dataframe_with_zero_net_units(self, base_df):
130130

131131
pd.testing.assert_frame_equal(segment_stats, expected_output)
132132

133+
def test_excludes_total_row_when_calc_total_false(self, base_df):
134+
"""Test that the method excludes the total row when calc_total=False."""
135+
expected_output = pd.DataFrame(
136+
{
137+
"segment_name": ["A", "B"],
138+
cols.agg_unit_spend: [500.0, 500.0],
139+
cols.agg_transaction_id: [3, 2],
140+
cols.agg_customer_id: [3, 2],
141+
cols.agg_unit_qty: [50, 50],
142+
cols.calc_spend_per_cust: [166.666667, 250.0],
143+
cols.calc_spend_per_trans: [166.666667, 250.0],
144+
cols.calc_trans_per_cust: [1.0, 1.0],
145+
cols.calc_price_per_unit: [10.0, 10.0],
146+
cols.calc_units_per_trans: [16.666667, 25.0],
147+
cols.customers_pct: [1.0, 1.0],
148+
},
149+
)
150+
151+
segment_stats = (
152+
SegTransactionStats(base_df, "segment_name", calc_total=False)
153+
.df.sort_values("segment_name")
154+
.reset_index(drop=True)
155+
)
156+
157+
pd.testing.assert_frame_equal(segment_stats, expected_output)
158+
133159

134160
class TestThresholdSegmentation:
135161
"""Tests for the ThresholdSegmentation class."""
@@ -407,7 +433,11 @@ def test_extra_aggs_functionality(self):
407433
)
408434

409435
# Test with a single extra aggregation
410-
seg_stats = SegTransactionStats(df, "segment_name", extra_aggs={"distinct_stores": ("store_id", "nunique")})
436+
seg_stats = SegTransactionStats(
437+
df,
438+
"segment_name",
439+
extra_aggs={"distinct_stores": ("store_id", "nunique")},
440+
)
411441

412442
# Verify the extra column exists and has correct values
413443
assert "distinct_stores" in seg_stats.df.columns

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)