Skip to content

Commit de3a12b

Browse files
committed
feat: convert seg stats to use Ibis
1 parent f4b4825 commit de3a12b

File tree

5 files changed

+151
-135
lines changed

5 files changed

+151
-135
lines changed

docs/examples/segmentation.ipynb

Lines changed: 43 additions & 61 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ ignore = [
6868
"TRY003", # Disable until we start creating proper exception classes
6969
"PT011", # Disable until we start creating proper exception classes
7070
"PTH123", # Not using open() to open files
71+
"SLF001", # Ibis makes a lot of use of the ibis._[column] which triggers this
7172
]
7273
select = [
7374
"A", # Builtins

pyretailscience/options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def load_from_toml(cls, file_path: str) -> "Options":
237237

238238
for section, options in toml_data.items():
239239
for option_name, option_value in Options.flatten_options(section, options).items():
240-
if option_name in options_instance._options: # noqa: SLF001
240+
if option_name in options_instance._options:
241241
options_instance.set_option(option_name, option_value)
242242
else:
243243
msg = f"Unknown option in TOML file: {option_name}"
@@ -392,6 +392,7 @@ def __init__(self) -> None:
392392
self.agg_customer_id_diff = self.join_options("column.agg.customer_id", "column.suffix.difference")
393393
self.agg_customer_id_pct_diff = self.join_options("column.agg.customer_id", "column.suffix.percent_difference")
394394
self.agg_customer_id_contrib = self.join_options("column.agg.customer_id", "column.suffix.contribution")
395+
self.customers_pct = self.join_options("column.agg.customer_id", "column.suffix.percent")
395396
# Transactions
396397
self.transaction_id = get_option("column.transaction_id")
397398
self.agg_transaction_id = get_option("column.agg.transaction_id")

pyretailscience/segmentation.py

Lines changed: 93 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from typing import Literal
44

5-
import duckdb
65
import ibis
76
import pandas as pd
8-
from duckdb import DuckDBPyRelation
97
from matplotlib.axes import Axes, SubplotBase
108

119
import pyretailscience.style.graph_utils as gu
@@ -155,7 +153,7 @@ class HMLSegmentation(ThresholdSegmentation):
155153

156154
def __init__(
157155
self,
158-
df: pd.DataFrame,
156+
df: pd.DataFrame | ibis.Table,
159157
value_col: str | None = None,
160158
agg_func: str = "sum",
161159
zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
@@ -189,24 +187,27 @@ def __init__(
189187
class SegTransactionStats:
190188
"""Calculates transaction statistics by segment."""
191189

192-
def __init__(self, data: pd.DataFrame | DuckDBPyRelation, segment_col: str = "segment_name") -> None:
190+
_df: pd.DataFrame | None = None
191+
192+
def __init__(self, data: pd.DataFrame | ibis.Table, segment_col: str = "segment_name") -> None:
193193
"""Calculates transaction statistics by segment.
194194
195195
Args:
196-
data (pd.DataFrame | DuckDBPyRelation): The transaction data. The dataframe must contain the columns
196+
data (pd.DataFrame | ibis.Table): The transaction data. The dataframe must contain the columns
197197
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
198198
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
199199
units_per_transaction.
200200
segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
201201
"""
202+
cols = ColumnHelper()
202203
required_cols = [
203-
get_option("column.customer_id"),
204-
get_option("column.unit_spend"),
205-
get_option("column.transaction_id"),
204+
cols.customer_id,
205+
cols.unit_spend,
206+
cols.transaction_id,
206207
segment_col,
207208
]
208-
if get_option("column.unit_quantity") in data.columns:
209-
required_cols.append(get_option("column.unit_quantity"))
209+
if cols.unit_qty in data.columns:
210+
required_cols.append(cols.unit_qty)
210211

211212
missing_cols = set(required_cols) - set(data.columns)
212213
if len(missing_cols) > 0:
@@ -215,66 +216,103 @@ def __init__(self, data: pd.DataFrame | DuckDBPyRelation, segment_col: str = "se
215216

216217
self.segment_col = segment_col
217218

218-
self.df = self._calc_seg_stats(data, segment_col)
219+
self.table = self._calc_seg_stats(data, segment_col)
219220

220221
@staticmethod
221-
def _calc_seg_stats(data: pd.DataFrame | DuckDBPyRelation, segment_col: str) -> pd.DataFrame:
222+
def _get_col_order(include_quantity: bool) -> list[str]:
223+
"""Returns the default column order.
224+
225+
Columns should be supplied in the same order regardless of the function being called.
226+
227+
Args:
228+
include_quantity (bool): Whether to include the columns related to quantity.
229+
230+
Returns:
231+
list[str]: The default column order.
232+
"""
233+
cols = ColumnHelper()
234+
col_order = [
235+
cols.agg_unit_spend,
236+
cols.agg_transaction_id,
237+
cols.agg_customer_id,
238+
cols.calc_spend_per_cust,
239+
cols.calc_spend_per_trans,
240+
cols.calc_trans_per_cust,
241+
cols.customers_pct,
242+
]
243+
if include_quantity:
244+
col_order.insert(3, "units")
245+
col_order.insert(7, cols.calc_units_per_trans)
246+
col_order.insert(7, cols.calc_price_per_unit)
247+
248+
return col_order
249+
250+
@staticmethod
251+
def _calc_seg_stats(data: pd.DataFrame | ibis.Table, segment_col: str) -> ibis.Table:
222252
"""Calculates the transaction statistics by segment.
223253
224254
Args:
225-
data (DuckDBPyRelation): The transaction data.
255+
data (pd.DataFrame | ibis.Table): The transaction data.
226256
segment_col (str): The column to use for the segmentation.
227257
228258
Returns:
229259
pd.DataFrame: The transaction statistics by segment.
230260
231261
"""
232262
if isinstance(data, pd.DataFrame):
233-
data = duckdb.from_df(data)
234-
elif not isinstance(data, DuckDBPyRelation):
235-
raise TypeError("data must be either a pandas DataFrame or a DuckDBPyRelation")
236-
237-
base_aggs = [
238-
f"SUM({get_option('column.unit_spend')}) as {get_option('column.agg.unit_spend')},",
239-
f"COUNT(DISTINCT {get_option('column.transaction_id')}) as {get_option('column.agg.transaction_id')},",
240-
f"COUNT(DISTINCT {get_option('column.customer_id')}) as {get_option('column.agg.customer_id')},",
241-
]
263+
data = ibis.memtable(data)
242264

243-
total_customers = data.aggregate("COUNT(DISTINCT customer_id)").fetchone()[0]
244-
return_cols = [
245-
"*,",
246-
f"{get_option('column.agg.unit_spend')} / {get_option('column.agg.customer_id')} ",
247-
f"as {get_option('column.calc.spend_per_customer')},",
248-
f"{get_option('column.agg.unit_spend')} / {get_option('column.agg.transaction_id')} ",
249-
f"as {get_option('column.calc.spend_per_transaction')},",
250-
f"{get_option('column.agg.transaction_id')} / {get_option('column.agg.customer_id')} ",
251-
f"as {get_option('column.calc.transactions_per_customer')},",
252-
f"{get_option('column.agg.customer_id')} / {total_customers}",
253-
f"as customers_{get_option('column.suffix.percent')},",
254-
]
265+
elif not isinstance(data, ibis.Table):
266+
raise TypeError("data must be either a pandas DataFrame or a ibis Table")
255267

256-
if get_option("column.unit_quantity") in data.columns:
257-
base_aggs.append(
258-
f"SUM({get_option('column.unit_quantity')})::bigint as {get_option('column.agg.unit_quantity')},",
259-
)
260-
return_cols.extend(
261-
[
262-
f"({get_option('column.agg.unit_spend')} / {get_option('column.agg.unit_quantity')}) ",
263-
f"as {get_option('column.calc.price_per_unit')},",
264-
f"({get_option('column.agg.unit_quantity')} / {get_option('column.agg.transaction_id')}) ",
265-
f"as {get_option('column.calc.units_per_transaction')},",
266-
],
267-
)
268+
cols = ColumnHelper()
268269

269-
segment_stats = data.aggregate(f"{segment_col} as segment_name," + "".join(base_aggs))
270-
total_stats = data.aggregate("'Total' as segment_name," + "".join(base_aggs))
271-
final_stats_df = segment_stats.union(total_stats).select("".join(return_cols)).df()
272-
final_stats_df = final_stats_df.set_index("segment_name").sort_index()
270+
# Base aggregations for segments
271+
aggs = {
272+
cols.agg_unit_spend: data[cols.unit_spend].sum(),
273+
cols.agg_transaction_id: data[cols.transaction_id].nunique(),
274+
cols.agg_customer_id: data[cols.customer_id].nunique(),
275+
}
276+
if cols.unit_qty in data.columns:
277+
aggs[cols.agg_unit_qty] = data[cols.unit_qty].sum()
278+
279+
# Calculate metrics for segments and total
280+
segment_metrics = data.group_by(segment_col).aggregate(**aggs)
281+
total_metrics = data.aggregate(**aggs).mutate(**{segment_col: ibis.literal("Total")})
282+
283+
total_customers = data[cols.customer_id].nunique()
284+
285+
# Cross join with total_customers to make it available for percentage calculation
286+
final_metrics = ibis.union(segment_metrics, total_metrics).mutate(
287+
**{
288+
cols.calc_spend_per_cust: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_customer_id],
289+
cols.calc_spend_per_trans: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_transaction_id],
290+
cols.calc_trans_per_cust: ibis._[cols.agg_transaction_id] / ibis._[cols.agg_customer_id],
291+
cols.customers_pct: ibis._[cols.agg_customer_id] / total_customers,
292+
},
293+
)
273294

274-
# Make sure Total is the last row
275-
desired_index_sort = final_stats_df.index.drop("Total").tolist() + ["Total"] # noqa: RUF005
295+
if cols.unit_qty in data.columns:
296+
final_metrics = final_metrics.mutate(
297+
**{
298+
cols.calc_price_per_unit: ibis._[cols.agg_unit_spend] / ibis._[cols.agg_unit_qty],
299+
cols.calc_units_per_trans: ibis._[cols.agg_unit_qty] / ibis._[cols.agg_transaction_id],
300+
},
301+
)
302+
303+
return final_metrics
276304

277-
return final_stats_df.reindex(desired_index_sort)
305+
@property
306+
def df(self) -> pd.DataFrame:
307+
"""Returns the dataframe with the transaction statistics by segment."""
308+
if self._df is None:
309+
cols = ColumnHelper()
310+
col_order = [
311+
self.segment_col,
312+
*SegTransactionStats._get_col_order(include_quantity=cols.agg_unit_qty in self.table.columns),
313+
]
314+
self._df = self.table.execute()[col_order]
315+
return self._df
278316

279317
def plot(
280318
self,
@@ -325,9 +363,9 @@ def plot(
325363
if orientation == "horizontal":
326364
kind = "barh"
327365

328-
val_s = self.df[value_col]
366+
val_s = self.df.set_index(self.segment_col)[value_col]
329367
if hide_total:
330-
val_s = val_s[val_s.index != "total"]
368+
val_s = val_s[val_s.index != "Total"]
331369

332370
if sort_order is not None:
333371
ascending = sort_order == "ascending"

tests/test_segmentation.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def base_df(self):
1818
return pd.DataFrame(
1919
{
2020
cols.customer_id: [1, 2, 3, 4, 5],
21-
cols.unit_spend: [100, 200, 150, 300, 250],
21+
cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0],
2222
cols.transaction_id: [101, 102, 103, 104, 105],
2323
"segment_name": ["A", "B", "A", "B", "A"],
2424
cols.unit_qty: [10, 20, 15, 30, 25],
@@ -37,21 +37,22 @@ def test_correctly_calculates_revenue_transactions_customers_per_segment(self, b
3737
cols.calc_spend_per_cust: [166.666667, 250.0, 200.0],
3838
cols.calc_spend_per_trans: [166.666667, 250.0, 200.0],
3939
cols.calc_trans_per_cust: [1.0, 1.0, 1.0],
40-
f"customers_{get_option('column.suffix.percent')}": [0.6, 0.4, 1.0],
4140
cols.calc_price_per_unit: [10.0, 10.0, 10.0],
4241
cols.calc_units_per_trans: [16.666667, 25.0, 20.0],
42+
f"customers_{get_option('column.suffix.percent')}": [0.6, 0.4, 1.0],
4343
},
44-
).set_index("segment_name")
45-
46-
segment_stats = SegTransactionStats._calc_seg_stats(base_df, "segment_name")
44+
)
45+
segment_stats = (
46+
SegTransactionStats(base_df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)
47+
)
4748
pd.testing.assert_frame_equal(segment_stats, expected_output)
4849

4950
def test_correctly_calculates_revenue_transactions_customers(self):
5051
"""Test that the method correctly calculates at the transaction level."""
5152
df = pd.DataFrame(
5253
{
5354
get_option("column.customer_id"): [1, 2, 3, 4, 5],
54-
cols.unit_spend: [100, 200, 150, 300, 250],
55+
cols.unit_spend: [100.0, 200.0, 150.0, 300.0, 250.0],
5556
cols.transaction_id: [101, 102, 103, 104, 105],
5657
"segment_name": ["A", "B", "A", "B", "A"],
5758
},
@@ -68,18 +69,11 @@ def test_correctly_calculates_revenue_transactions_customers(self):
6869
cols.calc_trans_per_cust: [1.0, 1.0, 1.0],
6970
f"customers_{get_option('column.suffix.percent')}": [0.6, 0.4, 1.0],
7071
},
71-
).set_index("segment_name")
72+
)
7273

73-
segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_name")
74+
segment_stats = SegTransactionStats(df, "segment_name").df.sort_values("segment_name").reset_index(drop=True)
7475
pd.testing.assert_frame_equal(segment_stats, expected_output)
7576

76-
def test_does_not_alter_original_dataframe(self, base_df):
77-
"""Test that the method does not alter the original DataFrame."""
78-
original_df = base_df.copy()
79-
_ = SegTransactionStats._calc_seg_stats(base_df, "segment_name")
80-
81-
pd.testing.assert_frame_equal(base_df, original_df)
82-
8377
def test_handles_dataframe_with_one_segment(self, base_df):
8478
"""Test that the method correctly handles a DataFrame with only one segment."""
8579
df = base_df.copy()
@@ -95,13 +89,13 @@ def test_handles_dataframe_with_one_segment(self, base_df):
9589
cols.calc_spend_per_cust: [200.0, 200.0],
9690
cols.calc_spend_per_trans: [200.0, 200.0],
9791
cols.calc_trans_per_cust: [1.0, 1.0],
98-
f"customers_{get_option('column.suffix.percent')}": [1.0, 1.0],
9992
cols.calc_price_per_unit: [10.0, 10.0],
10093
cols.calc_units_per_trans: [20.0, 20.0],
94+
f"customers_{get_option('column.suffix.percent')}": [1.0, 1.0],
10195
},
102-
).set_index("segment_name")
96+
)
10397

104-
segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_name")
98+
segment_stats = SegTransactionStats(df, "segment_name").df
10599
pd.testing.assert_frame_equal(segment_stats, expected_output)
106100

107101

0 commit comments

Comments
 (0)