@@ -193,7 +193,7 @@ class SegTransactionStats:
193
193
def __init__ (
194
194
self ,
195
195
data : pd .DataFrame | ibis .Table ,
196
- segment_col : str = "segment_name" ,
196
+ segment_col : str | list [ str ] = "segment_name" ,
197
197
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
198
198
) -> None :
199
199
"""Calculates transaction statistics by segment.
@@ -203,7 +203,8 @@ def __init__(
203
203
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
204
204
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
205
205
units_per_transaction.
206
- segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
206
+ segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207
+ Defaults to "segment_name".
207
208
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
208
209
The keys in the dictionary will be the column names for the aggregation results.
209
210
The values are tuples with (column_name, aggregation_function), where:
@@ -212,11 +213,14 @@ def __init__(
212
213
Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
213
214
"""
214
215
cols = ColumnHelper ()
216
+
217
+ if isinstance (segment_col , str ):
218
+ segment_col = [segment_col ]
215
219
required_cols = [
216
220
cols .customer_id ,
217
221
cols .unit_spend ,
218
222
cols .transaction_id ,
219
- segment_col ,
223
+ * segment_col ,
220
224
]
221
225
if cols .unit_qty in data .columns :
222
226
required_cols .append (cols .unit_qty )
@@ -274,14 +278,14 @@ def _get_col_order(include_quantity: bool) -> list[str]:
274
278
@staticmethod
275
279
def _calc_seg_stats (
276
280
data : pd .DataFrame | ibis .Table ,
277
- segment_col : str ,
281
+ segment_col : list [ str ] ,
278
282
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
279
283
) -> ibis .Table :
280
284
"""Calculates the transaction statistics by segment.
281
285
282
286
Args:
283
287
data (pd.DataFrame | ibis.Table): The transaction data.
284
- segment_col (str): The column to use for the segmentation.
288
+ segment_col (list[ str] ): The columns to use for the segmentation.
285
289
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
286
290
The keys in the dictionary will be the column names for the aggregation results.
287
291
The values are tuples with (column_name, aggregation_function).
@@ -315,7 +319,7 @@ def _calc_seg_stats(
315
319
316
320
# Calculate metrics for segments and total
317
321
segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
318
- total_metrics = data .aggregate (** aggs ).mutate (segment_name = ibis .literal ("Total" ))
322
+ total_metrics = data .aggregate (** aggs ).mutate ({ col : ibis .literal ("Total" ) for col in segment_col } )
319
323
total_customers = data [cols .customer_id ].nunique ()
320
324
321
325
# Cross join with total_customers to make it available for percentage calculation
@@ -344,7 +348,7 @@ def df(self) -> pd.DataFrame:
344
348
if self ._df is None :
345
349
cols = ColumnHelper ()
346
350
col_order = [
347
- self .segment_col ,
351
+ * self .segment_col ,
348
352
* SegTransactionStats ._get_col_order (include_quantity = cols .agg_unit_qty in self .table .columns ),
349
353
]
350
354
@@ -393,18 +397,23 @@ def plot(
393
397
Raises:
394
398
ValueError: If the sort_order is not "ascending", "descending" or None.
395
399
ValueError: If the orientation is not "vertical" or "horizontal".
400
+ ValueError: If multiple segment columns are used, as plotting is only supported for a single segment column.
396
401
"""
397
402
if sort_order not in ["ascending" , "descending" , None ]:
398
403
raise ValueError ("sort_order must be either 'ascending' or 'descending' or None" )
399
404
if orientation not in ["vertical" , "horizontal" ]:
400
405
raise ValueError ("orientation must be either 'vertical' or 'horizontal'" )
406
+ if len (self .segment_col ) > 1 :
407
+ raise ValueError ("Plotting is only supported for a single segment column" )
401
408
402
409
default_title = f"{ value_col .title ()} by Segment"
403
410
kind = "bar"
404
411
if orientation == "horizontal" :
405
412
kind = "barh"
406
413
407
- val_s = self .df .set_index (self .segment_col )[value_col ]
414
+ # Use the first segment column for plotting
415
+ plot_segment_col = self .segment_col [0 ]
416
+ val_s = self .df .set_index (plot_segment_col )[value_col ]
408
417
if hide_total :
409
418
val_s = val_s [val_s .index != "Total" ]
410
419
@@ -462,7 +471,7 @@ class RFMSegmentation:
462
471
463
472
_df : pd .DataFrame | None = None
464
473
465
- def __init__ (self , df : pd .DataFrame | ibis .Table , current_date : str | None = None ) -> None :
474
+ def __init__ (self , df : pd .DataFrame | ibis .Table , current_date : str | datetime . date | None = None ) -> None :
466
475
"""Initializes the RFM segmentation process.
467
476
468
477
Args:
@@ -472,8 +481,8 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
472
481
- transaction_date
473
482
- unit_spend
474
483
- transaction_id
475
- current_date (Optional[str] ): The reference date for calculating recency (format: "YYYY-MM-DD") .
476
- If not provided, the current system date will be used .
484
+ current_date (Optional[Union[ str, datetime.date]] ): The reference date for calculating recency.
485
+ Can be a string (format: "YYYY-MM-DD"), a date object, or None (defaults to the current system date) .
477
486
478
487
Raises:
479
488
ValueError: If the dataframe is missing required columns.
@@ -491,9 +500,13 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
491
500
if missing_cols :
492
501
error_message = f"Missing required columns: { missing_cols } "
493
502
raise ValueError (error_message )
494
- current_date = (
495
- datetime .date .fromisoformat (current_date ) if current_date else datetime .datetime .now (datetime .UTC ).date ()
496
- )
503
+
504
+ if isinstance (current_date , str ):
505
+ current_date = datetime .date .fromisoformat (current_date )
506
+ elif current_date is None :
507
+ current_date = datetime .datetime .now (datetime .UTC ).date ()
508
+ elif not isinstance (current_date , datetime .date ):
509
+ raise TypeError ("current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None" )
497
510
498
511
self .table = self ._compute_rfm (df , current_date )
499
512
@@ -537,13 +550,19 @@ def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Tabl
537
550
m_score = (ibis .ntile (10 ).over (window_monetary )),
538
551
)
539
552
540
- rfm_segment = (rfm_scores .r_score * 100 + rfm_scores .f_score * 10 + rfm_scores .m_score ).name ("rfm_segment" )
541
-
542
- return rfm_scores .mutate (rfm_segment = rfm_segment )
553
+ return rfm_scores .mutate (
554
+ rfm_segment = (rfm_scores .r_score * 100 + rfm_scores .f_score * 10 + rfm_scores .m_score ),
555
+ fm_segment = (rfm_scores .f_score * 10 + rfm_scores .m_score ),
556
+ )
543
557
544
558
@property
def df(self) -> pd.DataFrame:
    """Return the RFM result as a dataframe indexed by customer id.

    The dataframe is built by executing ``self.table`` the first time the
    property is read and is then cached on ``self._df``; later reads return
    the cached object without executing the table again.

    Returns:
        pd.DataFrame: The executed table, indexed by the column named by the
            "column.customer_id" option.
    """
    # Fast path: the table has already been executed and cached.
    if self._df is not None:
        return self._df

    executed = self.table.execute()
    self._df = executed.set_index(get_option("column.customer_id"))
    return self._df
564
+
565
@property
def ibis_table(self) -> ibis.Table:
    """Expose the Ibis table that holds the computed RFM segmentation.

    Returns:
        ibis.Table: The object stored on ``self.table``, returned as-is
            (no execution or copy is performed here).
    """
    rfm_table = self.table
    return rfm_table
0 commit comments