2
2
3
3
from typing import Literal
4
4
5
- import duckdb
6
5
import ibis
7
6
import pandas as pd
8
- from duckdb import DuckDBPyRelation
9
7
from matplotlib .axes import Axes , SubplotBase
10
8
11
9
import pyretailscience .style .graph_utils as gu
@@ -155,7 +153,7 @@ class HMLSegmentation(ThresholdSegmentation):
155
153
156
154
def __init__ (
157
155
self ,
158
- df : pd .DataFrame ,
156
+ df : pd .DataFrame | ibis . Table ,
159
157
value_col : str | None = None ,
160
158
agg_func : str = "sum" ,
161
159
zero_value_customers : Literal ["separate_segment" , "exclude" , "include_with_light" ] = "separate_segment" ,
@@ -189,24 +187,27 @@ def __init__(
189
187
class SegTransactionStats :
190
188
"""Calculates transaction statistics by segment."""
191
189
192
- def __init__ (self , data : pd .DataFrame | DuckDBPyRelation , segment_col : str = "segment_name" ) -> None :
190
+ _df : pd .DataFrame | None = None
191
+
192
+ def __init__ (self , data : pd .DataFrame | ibis .Table , segment_col : str = "segment_name" ) -> None :
193
193
"""Calculates transaction statistics by segment.
194
194
195
195
Args:
196
- data (pd.DataFrame | DuckDBPyRelation ): The transaction data. The dataframe must contain the columns
196
+ data (pd.DataFrame | ibis.Table ): The transaction data. The dataframe must contain the columns
197
197
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
198
198
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
199
199
units_per_transaction.
200
200
segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
201
201
"""
202
+ cols = ColumnHelper ()
202
203
required_cols = [
203
- get_option ( "column .customer_id" ) ,
204
- get_option ( "column .unit_spend" ) ,
205
- get_option ( "column .transaction_id" ) ,
204
+ cols .customer_id ,
205
+ cols .unit_spend ,
206
+ cols .transaction_id ,
206
207
segment_col ,
207
208
]
208
- if get_option ( "column.unit_quantity" ) in data .columns :
209
- required_cols .append (get_option ( "column.unit_quantity" ) )
209
+ if cols . unit_qty in data .columns :
210
+ required_cols .append (cols . unit_qty )
210
211
211
212
missing_cols = set (required_cols ) - set (data .columns )
212
213
if len (missing_cols ) > 0 :
@@ -215,66 +216,103 @@ def __init__(self, data: pd.DataFrame | DuckDBPyRelation, segment_col: str = "se
215
216
216
217
self .segment_col = segment_col
217
218
218
- self .df = self ._calc_seg_stats (data , segment_col )
219
+ self .table = self ._calc_seg_stats (data , segment_col )
219
220
220
221
@staticmethod
221
- def _calc_seg_stats (data : pd .DataFrame | DuckDBPyRelation , segment_col : str ) -> pd .DataFrame :
222
+ def _get_col_order (include_quantity : bool ) -> list [str ]:
223
+ """Returns the default column order.
224
+
225
+ Columns should be supplied in the same order regardless of the function being called.
226
+
227
+ Args:
228
+ include_quantity (bool): Whether to include the columns related to quantity.
229
+
230
+ Returns:
231
+ list[str]: The default column order.
232
+ """
233
+ cols = ColumnHelper ()
234
+ col_order = [
235
+ cols .agg_unit_spend ,
236
+ cols .agg_transaction_id ,
237
+ cols .agg_customer_id ,
238
+ cols .calc_spend_per_cust ,
239
+ cols .calc_spend_per_trans ,
240
+ cols .calc_trans_per_cust ,
241
+ cols .customers_pct ,
242
+ ]
243
+ if include_quantity :
244
+ col_order .insert (3 , "units" )
245
+ col_order .insert (7 , cols .calc_units_per_trans )
246
+ col_order .insert (7 , cols .calc_price_per_unit )
247
+
248
+ return col_order
249
+
250
+ @staticmethod
251
+ def _calc_seg_stats (data : pd .DataFrame | ibis .Table , segment_col : str ) -> ibis .Table :
222
252
"""Calculates the transaction statistics by segment.
223
253
224
254
Args:
225
- data (DuckDBPyRelation ): The transaction data.
255
+ data (pd.DataFrame | ibis.Table ): The transaction data.
226
256
segment_col (str): The column to use for the segmentation.
227
257
228
258
Returns:
229
259
pd.DataFrame: The transaction statistics by segment.
230
260
231
261
"""
232
262
if isinstance (data , pd .DataFrame ):
233
- data = duckdb .from_df (data )
234
- elif not isinstance (data , DuckDBPyRelation ):
235
- raise TypeError ("data must be either a pandas DataFrame or a DuckDBPyRelation" )
236
-
237
- base_aggs = [
238
- f"SUM({ get_option ('column.unit_spend' )} ) as { get_option ('column.agg.unit_spend' )} ," ,
239
- f"COUNT(DISTINCT { get_option ('column.transaction_id' )} ) as { get_option ('column.agg.transaction_id' )} ," ,
240
- f"COUNT(DISTINCT { get_option ('column.customer_id' )} ) as { get_option ('column.agg.customer_id' )} ," ,
241
- ]
263
+ data = ibis .memtable (data )
242
264
243
- total_customers = data .aggregate ("COUNT(DISTINCT customer_id)" ).fetchone ()[0 ]
244
- return_cols = [
245
- "*," ,
246
- f"{ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.customer_id' )} " ,
247
- f"as { get_option ('column.calc.spend_per_customer' )} ," ,
248
- f"{ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.transaction_id' )} " ,
249
- f"as { get_option ('column.calc.spend_per_transaction' )} ," ,
250
- f"{ get_option ('column.agg.transaction_id' )} / { get_option ('column.agg.customer_id' )} " ,
251
- f"as { get_option ('column.calc.transactions_per_customer' )} ," ,
252
- f"{ get_option ('column.agg.customer_id' )} / { total_customers } " ,
253
- f"as customers_{ get_option ('column.suffix.percent' )} ," ,
254
- ]
265
+ elif not isinstance (data , ibis .Table ):
266
+ raise TypeError ("data must be either a pandas DataFrame or a ibis Table" )
255
267
256
- if get_option ("column.unit_quantity" ) in data .columns :
257
- base_aggs .append (
258
- f"SUM({ get_option ('column.unit_quantity' )} )::bigint as { get_option ('column.agg.unit_quantity' )} ," ,
259
- )
260
- return_cols .extend (
261
- [
262
- f"({ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.unit_quantity' )} ) " ,
263
- f"as { get_option ('column.calc.price_per_unit' )} ," ,
264
- f"({ get_option ('column.agg.unit_quantity' )} / { get_option ('column.agg.transaction_id' )} ) " ,
265
- f"as { get_option ('column.calc.units_per_transaction' )} ," ,
266
- ],
267
- )
268
+ cols = ColumnHelper ()
268
269
269
- segment_stats = data .aggregate (f"{ segment_col } as segment_name," + "" .join (base_aggs ))
270
- total_stats = data .aggregate ("'Total' as segment_name," + "" .join (base_aggs ))
271
- final_stats_df = segment_stats .union (total_stats ).select ("" .join (return_cols )).df ()
272
- final_stats_df = final_stats_df .set_index ("segment_name" ).sort_index ()
270
+ # Base aggregations for segments
271
+ aggs = {
272
+ cols .agg_unit_spend : data [cols .unit_spend ].sum (),
273
+ cols .agg_transaction_id : data [cols .transaction_id ].nunique (),
274
+ cols .agg_customer_id : data [cols .customer_id ].nunique (),
275
+ }
276
+ if cols .unit_qty in data .columns :
277
+ aggs [cols .agg_unit_qty ] = data [cols .unit_qty ].sum ()
278
+
279
+ # Calculate metrics for segments and total
280
+ segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
281
+ total_metrics = data .aggregate (** aggs ).mutate (** {segment_col : ibis .literal ("Total" )})
282
+
283
+ total_customers = data [cols .customer_id ].nunique ()
284
+
285
+ # Cross join with total_customers to make it available for percentage calculation
286
+ final_metrics = ibis .union (segment_metrics , total_metrics ).mutate (
287
+ ** {
288
+ cols .calc_spend_per_cust : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_customer_id ],
289
+ cols .calc_spend_per_trans : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_transaction_id ],
290
+ cols .calc_trans_per_cust : ibis ._ [cols .agg_transaction_id ] / ibis ._ [cols .agg_customer_id ],
291
+ cols .customers_pct : ibis ._ [cols .agg_customer_id ] / total_customers ,
292
+ },
293
+ )
273
294
274
- # Make sure Total is the last row
275
- desired_index_sort = final_stats_df .index .drop ("Total" ).tolist () + ["Total" ] # noqa: RUF005
295
+ if cols .unit_qty in data .columns :
296
+ final_metrics = final_metrics .mutate (
297
+ ** {
298
+ cols .calc_price_per_unit : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_unit_qty ],
299
+ cols .calc_units_per_trans : ibis ._ [cols .agg_unit_qty ] / ibis ._ [cols .agg_transaction_id ],
300
+ },
301
+ )
302
+
303
+ return final_metrics
276
304
277
- return final_stats_df .reindex (desired_index_sort )
305
+ @property
306
+ def df (self ) -> pd .DataFrame :
307
+ """Returns the dataframe with the transaction statistics by segment."""
308
+ if self ._df is None :
309
+ cols = ColumnHelper ()
310
+ col_order = [
311
+ self .segment_col ,
312
+ * SegTransactionStats ._get_col_order (include_quantity = cols .agg_unit_qty in self .table .columns ),
313
+ ]
314
+ self ._df = self .table .execute ()[col_order ]
315
+ return self ._df
278
316
279
317
def plot (
280
318
self ,
@@ -325,9 +363,9 @@ def plot(
325
363
if orientation == "horizontal" :
326
364
kind = "barh"
327
365
328
- val_s = self .df [value_col ]
366
+ val_s = self .df . set_index ( self . segment_col ) [value_col ]
329
367
if hide_total :
330
- val_s = val_s [val_s .index != "total " ]
368
+ val_s = val_s [val_s .index != "Total " ]
331
369
332
370
if sort_order is not None :
333
371
ascending = sort_order == "ascending"
0 commit comments