22
33from typing import Literal
44
5- import duckdb
65import ibis
76import pandas as pd
8- from duckdb import DuckDBPyRelation
97from matplotlib .axes import Axes , SubplotBase
108
119import pyretailscience .style .graph_utils as gu
@@ -155,7 +153,7 @@ class HMLSegmentation(ThresholdSegmentation):
155153
156154 def __init__ (
157155 self ,
158- df : pd .DataFrame ,
156+ df : pd .DataFrame | ibis . Table ,
159157 value_col : str | None = None ,
160158 agg_func : str = "sum" ,
161159 zero_value_customers : Literal ["separate_segment" , "exclude" , "include_with_light" ] = "separate_segment" ,
@@ -189,24 +187,27 @@ def __init__(
189187class SegTransactionStats :
190188 """Calculates transaction statistics by segment."""
191189
192- def __init__ (self , data : pd .DataFrame | DuckDBPyRelation , segment_col : str = "segment_name" ) -> None :
190+ _df : pd .DataFrame | None = None
191+
192+ def __init__ (self , data : pd .DataFrame | ibis .Table , segment_col : str = "segment_name" ) -> None :
193193 """Calculates transaction statistics by segment.
194194
195195 Args:
196- data (pd.DataFrame | DuckDBPyRelation ): The transaction data. The dataframe must contain the columns
196+ data (pd.DataFrame | ibis.Table ): The transaction data. The dataframe must contain the columns
197197 customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
198198 the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
199199 units_per_transaction.
200200 segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
201201 """
202+ cols = ColumnHelper ()
202203 required_cols = [
203- get_option ( "column .customer_id" ) ,
204- get_option ( "column .unit_spend" ) ,
205- get_option ( "column .transaction_id" ) ,
204+ cols .customer_id ,
205+ cols .unit_spend ,
206+ cols .transaction_id ,
206207 segment_col ,
207208 ]
208- if get_option ( "column.unit_quantity" ) in data .columns :
209- required_cols .append (get_option ( "column.unit_quantity" ) )
209+ if cols . unit_qty in data .columns :
210+ required_cols .append (cols . unit_qty )
210211
211212 missing_cols = set (required_cols ) - set (data .columns )
212213 if len (missing_cols ) > 0 :
@@ -215,66 +216,103 @@ def __init__(self, data: pd.DataFrame | DuckDBPyRelation, segment_col: str = "se
215216
216217 self .segment_col = segment_col
217218
218- self .df = self ._calc_seg_stats (data , segment_col )
219+ self .table = self ._calc_seg_stats (data , segment_col )
219220
220221 @staticmethod
221- def _calc_seg_stats (data : pd .DataFrame | DuckDBPyRelation , segment_col : str ) -> pd .DataFrame :
222+ def _get_col_order (include_quantity : bool ) -> list [str ]:
223+ """Returns the default column order.
224+
225+ Columns should be supplied in the same order regardless of the function being called.
226+
227+ Args:
228+ include_quantity (bool): Whether to include the columns related to quantity.
229+
230+ Returns:
231+ list[str]: The default column order.
232+ """
233+ cols = ColumnHelper ()
234+ col_order = [
235+ cols .agg_unit_spend ,
236+ cols .agg_transaction_id ,
237+ cols .agg_customer_id ,
238+ cols .calc_spend_per_cust ,
239+ cols .calc_spend_per_trans ,
240+ cols .calc_trans_per_cust ,
241+ cols .customers_pct ,
242+ ]
243+ if include_quantity :
244+ col_order .insert (3 , "units" )
245+ col_order .insert (7 , cols .calc_units_per_trans )
246+ col_order .insert (7 , cols .calc_price_per_unit )
247+
248+ return col_order
249+
250+ @staticmethod
251+ def _calc_seg_stats (data : pd .DataFrame | ibis .Table , segment_col : str ) -> ibis .Table :
222252 """Calculates the transaction statistics by segment.
223253
224254 Args:
225- data (DuckDBPyRelation ): The transaction data.
255+ data (pd.DataFrame | ibis.Table ): The transaction data.
226256 segment_col (str): The column to use for the segmentation.
227257
228258 Returns:
229259 pd.DataFrame: The transaction statistics by segment.
230260
231261 """
232262 if isinstance (data , pd .DataFrame ):
233- data = duckdb .from_df (data )
234- elif not isinstance (data , DuckDBPyRelation ):
235- raise TypeError ("data must be either a pandas DataFrame or a DuckDBPyRelation" )
236-
237- base_aggs = [
238- f"SUM({ get_option ('column.unit_spend' )} ) as { get_option ('column.agg.unit_spend' )} ," ,
239- f"COUNT(DISTINCT { get_option ('column.transaction_id' )} ) as { get_option ('column.agg.transaction_id' )} ," ,
240- f"COUNT(DISTINCT { get_option ('column.customer_id' )} ) as { get_option ('column.agg.customer_id' )} ," ,
241- ]
263+ data = ibis .memtable (data )
242264
243- total_customers = data .aggregate ("COUNT(DISTINCT customer_id)" ).fetchone ()[0 ]
244- return_cols = [
245- "*," ,
246- f"{ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.customer_id' )} " ,
247- f"as { get_option ('column.calc.spend_per_customer' )} ," ,
248- f"{ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.transaction_id' )} " ,
249- f"as { get_option ('column.calc.spend_per_transaction' )} ," ,
250- f"{ get_option ('column.agg.transaction_id' )} / { get_option ('column.agg.customer_id' )} " ,
251- f"as { get_option ('column.calc.transactions_per_customer' )} ," ,
252- f"{ get_option ('column.agg.customer_id' )} / { total_customers } " ,
253- f"as customers_{ get_option ('column.suffix.percent' )} ," ,
254- ]
265+ elif not isinstance (data , ibis .Table ):
266+ raise TypeError ("data must be either a pandas DataFrame or a ibis Table" )
255267
256- if get_option ("column.unit_quantity" ) in data .columns :
257- base_aggs .append (
258- f"SUM({ get_option ('column.unit_quantity' )} )::bigint as { get_option ('column.agg.unit_quantity' )} ," ,
259- )
260- return_cols .extend (
261- [
262- f"({ get_option ('column.agg.unit_spend' )} / { get_option ('column.agg.unit_quantity' )} ) " ,
263- f"as { get_option ('column.calc.price_per_unit' )} ," ,
264- f"({ get_option ('column.agg.unit_quantity' )} / { get_option ('column.agg.transaction_id' )} ) " ,
265- f"as { get_option ('column.calc.units_per_transaction' )} ," ,
266- ],
267- )
268+ cols = ColumnHelper ()
268269
269- segment_stats = data .aggregate (f"{ segment_col } as segment_name," + "" .join (base_aggs ))
270- total_stats = data .aggregate ("'Total' as segment_name," + "" .join (base_aggs ))
271- final_stats_df = segment_stats .union (total_stats ).select ("" .join (return_cols )).df ()
272- final_stats_df = final_stats_df .set_index ("segment_name" ).sort_index ()
270+ # Base aggregations for segments
271+ aggs = {
272+ cols .agg_unit_spend : data [cols .unit_spend ].sum (),
273+ cols .agg_transaction_id : data [cols .transaction_id ].nunique (),
274+ cols .agg_customer_id : data [cols .customer_id ].nunique (),
275+ }
276+ if cols .unit_qty in data .columns :
277+ aggs [cols .agg_unit_qty ] = data [cols .unit_qty ].sum ()
278+
279+ # Calculate metrics for segments and total
280+ segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
281+ total_metrics = data .aggregate (** aggs ).mutate (** {segment_col : ibis .literal ("Total" )})
282+
283+ total_customers = data [cols .customer_id ].nunique ()
284+
285+ # Cross join with total_customers to make it available for percentage calculation
286+ final_metrics = ibis .union (segment_metrics , total_metrics ).mutate (
287+ ** {
288+ cols .calc_spend_per_cust : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_customer_id ],
289+ cols .calc_spend_per_trans : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_transaction_id ],
290+ cols .calc_trans_per_cust : ibis ._ [cols .agg_transaction_id ] / ibis ._ [cols .agg_customer_id ],
291+ cols .customers_pct : ibis ._ [cols .agg_customer_id ] / total_customers ,
292+ },
293+ )
273294
274- # Make sure Total is the last row
275- desired_index_sort = final_stats_df .index .drop ("Total" ).tolist () + ["Total" ] # noqa: RUF005
295+ if cols .unit_qty in data .columns :
296+ final_metrics = final_metrics .mutate (
297+ ** {
298+ cols .calc_price_per_unit : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_unit_qty ],
299+ cols .calc_units_per_trans : ibis ._ [cols .agg_unit_qty ] / ibis ._ [cols .agg_transaction_id ],
300+ },
301+ )
302+
303+ return final_metrics
276304
277- return final_stats_df .reindex (desired_index_sort )
305+ @property
306+ def df (self ) -> pd .DataFrame :
307+ """Returns the dataframe with the transaction statistics by segment."""
308+ if self ._df is None :
309+ cols = ColumnHelper ()
310+ col_order = [
311+ self .segment_col ,
312+ * SegTransactionStats ._get_col_order (include_quantity = cols .agg_unit_qty in self .table .columns ),
313+ ]
314+ self ._df = self .table .execute ()[col_order ]
315+ return self ._df
278316
279317 def plot (
280318 self ,
@@ -325,9 +363,9 @@ def plot(
325363 if orientation == "horizontal" :
326364 kind = "barh"
327365
328- val_s = self .df [value_col ]
366+ val_s = self .df . set_index ( self . segment_col ) [value_col ]
329367 if hide_total :
330- val_s = val_s [val_s .index != "total " ]
368+ val_s = val_s [val_s .index != "Total " ]
331369
332370 if sort_order is not None :
333371 ascending = sort_order == "ascending"
0 commit comments