@@ -194,6 +194,7 @@ def __init__(
194
194
self ,
195
195
data : pd .DataFrame | ibis .Table ,
196
196
segment_col : str | list [str ] = "segment_name" ,
197
+ calc_total : bool = True ,
197
198
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
198
199
) -> None :
199
200
"""Calculates transaction statistics by segment.
@@ -205,6 +206,7 @@ def __init__(
205
206
units_per_transaction.
206
207
segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207
208
Defaults to "segment_name".
209
+ calc_total (bool, optional): Whether to include the total row. Defaults to True.
208
210
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
209
211
The keys in the dictionary will be the column names for the aggregation results.
210
212
The values are tuples with (column_name, aggregation_function), where:
@@ -244,7 +246,7 @@ def __init__(
244
246
self .segment_col = segment_col
245
247
self .extra_aggs = {} if extra_aggs is None else extra_aggs
246
248
247
- self .table = self ._calc_seg_stats (data , segment_col , self .extra_aggs )
249
+ self .table = self ._calc_seg_stats (data , segment_col , calc_total , self .extra_aggs )
248
250
249
251
@staticmethod
250
252
def _get_col_order (include_quantity : bool ) -> list [str ]:
@@ -279,6 +281,7 @@ def _get_col_order(include_quantity: bool) -> list[str]:
279
281
def _calc_seg_stats (
280
282
data : pd .DataFrame | ibis .Table ,
281
283
segment_col : list [str ],
284
+ calc_total : bool = True ,
282
285
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
283
286
) -> ibis .Table :
284
287
"""Calculates the transaction statistics by segment.
@@ -287,6 +290,7 @@ def _calc_seg_stats(
287
290
data (pd.DataFrame | ibis.Table): The transaction data.
288
291
segment_col (list[str]): The columns to use for the segmentation.
289
292
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
293
+ calc_total (bool, optional): Whether to include the total row. Defaults to True.
290
294
The keys in the dictionary will be the column names for the aggregation results.
291
295
The values are tuples with (column_name, aggregation_function).
292
296
@@ -298,7 +302,7 @@ def _calc_seg_stats(
298
302
data = ibis .memtable (data )
299
303
300
304
elif not isinstance (data , ibis .Table ):
301
- raise TypeError ("data must be either a pandas DataFrame or a ibis Table" )
305
+ raise TypeError ("data must be either a pandas DataFrame or an ibis Table" )
302
306
303
307
cols = ColumnHelper ()
304
308
@@ -317,13 +321,18 @@ def _calc_seg_stats(
317
321
col , func = col_tuple
318
322
aggs [agg_name ] = getattr (data [col ], func )()
319
323
320
- # Calculate metrics for segments and total
324
+ # Calculate metrics for segments
321
325
segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
322
- total_metrics = data .aggregate (** aggs ).mutate ({col : ibis .literal ("Total" ) for col in segment_col })
326
+ final_metrics = segment_metrics
327
+
328
+ if calc_total :
329
+ total_metrics = data .aggregate (** aggs ).mutate ({col : ibis .literal ("Total" ) for col in segment_col })
330
+ final_metrics = ibis .union (segment_metrics , total_metrics )
331
+
323
332
total_customers = data [cols .customer_id ].nunique ()
324
333
325
334
# Cross join with total_customers to make it available for percentage calculation
326
- final_metrics = ibis . union ( segment_metrics , total_metrics ) .mutate (
335
+ final_metrics = final_metrics .mutate (
327
336
** {
328
337
cols .calc_spend_per_cust : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_customer_id ],
329
338
cols .calc_spend_per_trans : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_transaction_id ],
0 commit comments