@@ -194,6 +194,7 @@ def __init__(
194194 self ,
195195 data : pd .DataFrame | ibis .Table ,
196196 segment_col : str | list [str ] = "segment_name" ,
197+ calc_total : bool = True ,
197198 extra_aggs : dict [str , tuple [str , str ]] | None = None ,
198199 ) -> None :
199200 """Calculates transaction statistics by segment.
@@ -205,6 +206,7 @@ def __init__(
205206 units_per_transaction.
206207 segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207208 Defaults to "segment_name".
209+ calc_total (bool, optional): Whether to include the total row. Defaults to True.
208210 extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
209211 The keys in the dictionary will be the column names for the aggregation results.
210212 The values are tuples with (column_name, aggregation_function), where:
@@ -244,7 +246,7 @@ def __init__(
244246 self .segment_col = segment_col
245247 self .extra_aggs = {} if extra_aggs is None else extra_aggs
246248
247- self .table = self ._calc_seg_stats (data , segment_col , self .extra_aggs )
249+ self .table = self ._calc_seg_stats (data , segment_col , calc_total , self .extra_aggs )
248250
249251 @staticmethod
250252 def _get_col_order (include_quantity : bool ) -> list [str ]:
@@ -279,6 +281,7 @@ def _get_col_order(include_quantity: bool) -> list[str]:
279281 def _calc_seg_stats (
280282 data : pd .DataFrame | ibis .Table ,
281283 segment_col : list [str ],
284+ calc_total : bool = True ,
282285 extra_aggs : dict [str , tuple [str , str ]] | None = None ,
283286 ) -> ibis .Table :
284287 """Calculates the transaction statistics by segment.
@@ -287,6 +290,7 @@ def _calc_seg_stats(
287290 data (pd.DataFrame | ibis.Table): The transaction data.
288291 segment_col (list[str]): The columns to use for the segmentation.
292+ calc_total (bool, optional): Whether to include the total row. Defaults to True.
289293 extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
290294 The keys in the dictionary will be the column names for the aggregation results.
291295 The values are tuples with (column_name, aggregation_function).
292296
@@ -298,7 +302,7 @@ def _calc_seg_stats(
298302 data = ibis .memtable (data )
299303
300304 elif not isinstance (data , ibis .Table ):
301- raise TypeError ("data must be either a pandas DataFrame or a ibis Table" )
305+ raise TypeError ("data must be either a pandas DataFrame or an ibis Table" )
302306
303307 cols = ColumnHelper ()
304308
@@ -317,13 +321,18 @@ def _calc_seg_stats(
317321 col , func = col_tuple
318322 aggs [agg_name ] = getattr (data [col ], func )()
319323
320- # Calculate metrics for segments and total
324+ # Calculate metrics for segments
321325 segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
322- total_metrics = data .aggregate (** aggs ).mutate ({col : ibis .literal ("Total" ) for col in segment_col })
326+ final_metrics = segment_metrics
327+
328+ if calc_total :
329+ total_metrics = data .aggregate (** aggs ).mutate ({col : ibis .literal ("Total" ) for col in segment_col })
330+ final_metrics = ibis .union (segment_metrics , total_metrics )
331+
323332 total_customers = data [cols .customer_id ].nunique ()
324333
325334 # Cross join with total_customers to make it available for percentage calculation
326- final_metrics = ibis . union ( segment_metrics , total_metrics ) .mutate (
335+ final_metrics = final_metrics .mutate (
327336 ** {
328337 cols .calc_spend_per_cust : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_customer_id ],
329338 cols .calc_spend_per_trans : ibis ._ [cols .agg_unit_spend ] / ibis ._ [cols .agg_transaction_id ],
0 commit comments