Commit 2f63781

Merge branch 'main' of github.com:data-simply/pyretailscience into feature/rfm-segmentation
2 parents: 4f77514 + 54b8bbf

25 files changed (+161, -68 lines)

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions

@@ -1,12 +1,12 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.2.2"
+    rev: "v0.11.0"
     hooks:
       - id: ruff
         args: ["--fix"]
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer

docs/analysis_modules.md

Lines changed: 5 additions & 5 deletions

@@ -834,11 +834,11 @@ rfm_segmenter = RFMSegmentation(df=data, current_date=current_date)
 rfm_results = rfm_segmenter.df
 ```

-| customer_id | recency_days | frequency | monetary | r_score | f_score | m_score | rfm_segment |
-|-------------|--------------|-----------|----------|---------|---------|---------|-------------|
-| 3 | 147 | 3 | 750 | 0 | 0 | 0 | 0 |
-| 2 | 127 | 2 | 250 | 1 | 2 | 1 | 121 |
-| 1 | 113 | 2 | 125 | 2 | 1 | 2 | 212 |
+| customer_id | recency_days | frequency | monetary | r_score | f_score | m_score | rfm_segment | fm_segment |
+|-------------|--------------|-----------|----------|---------|---------|---------|-------------|------------|
+| 1 | 113 | 2 | 125 | 0 | 0 | 0 | 0 | 0 |
+| 2 | 127 | 2 | 250 | 1 | 1 | 1 | 111 | 11 |
+| 3 | 147 | 3 | 750 | 2 | 2 | 2 | 222 | 22 |

 ### Purchases Per Customer

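For context, the rfm_segment and fm_segment codes in the updated table are digit concatenations of the decile scores, the same arithmetic RFMSegmentation._compute_rfm performs in the segmentation.py diff below: rfm_segment = r_score * 100 + f_score * 10 + m_score, and fm_segment = f_score * 10 + m_score. A minimal sketch with plain integers standing in for the Ibis ntile scores (the helper name is illustrative):

```python
# Minimal sketch of the segment-code arithmetic; plain ints stand in for the
# ibis ntile() decile scores computed by RFMSegmentation._compute_rfm.
def segment_codes(r_score: int, f_score: int, m_score: int) -> tuple[int, int]:
    rfm_segment = r_score * 100 + f_score * 10 + m_score
    fm_segment = f_score * 10 + m_score
    return rfm_segment, fm_segment


assert segment_codes(1, 1, 1) == (111, 11)  # customer_id 2 in the table above
assert segment_codes(2, 2, 2) == (222, 22)  # customer_id 3 in the table above
```
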
docs/examples/cross_shop.ipynb

Lines changed: 3 additions & 1 deletion

@@ -238,7 +238,9 @@
 "source": [
 "shoes_idx = df[\"category_1_name\"] == \"Shoes\"\n",
 "df.loc[shoes_idx, \"category_1_name\"] = np.random.RandomState(42).choice(\n",
-"    [\"Shoes\", \"Jeans\"], size=shoes_idx.sum(), p=[0.5, 0.5],\n",
+"    [\"Shoes\", \"Jeans\"],\n",
+"    size=shoes_idx.sum(),\n",
+"    p=[0.5, 0.5],\n",
 ")"
 ]
 },
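
The reformatted cell keeps its behaviour: it relabels a reproducible random half of the "Shoes" rows as "Jeans" so the cross-shop example has overlapping categories. A standalone sketch of the same pattern (the tiny frame below is illustrative; the notebook's `df` is defined earlier in that example):

```python
import numpy as np
import pandas as pd

# Illustrative frame; the real notebook operates on its own example dataset.
df = pd.DataFrame({"category_1_name": ["Shoes"] * 6 + ["Socks"] * 2})

# Relabel a reproducible random half of the "Shoes" rows as "Jeans".
shoes_idx = df["category_1_name"] == "Shoes"
df.loc[shoes_idx, "category_1_name"] = np.random.RandomState(42).choice(
    ["Shoes", "Jeans"],
    size=shoes_idx.sum(),
    p=[0.5, 0.5],
)

print(df["category_1_name"].value_counts())
```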

docs/examples/gain_loss.ipynb

Lines changed: 3 additions & 1 deletion

@@ -254,7 +254,9 @@
 "# Reasign half the rows to Calvin Klein and leave the other half as Diesel\n",
 "p2_diesel_idx = time_period_2 & (df[\"brand_name\"] == \"Diesel\")\n",
 "df.loc[p2_diesel_idx, \"brand_name\"] = np.random.RandomState(42).choice(\n",
-"    [\"Calvin Klein\", \"Diesel\"], size=p2_diesel_idx.sum(), p=[0.75, 0.25],\n",
+"    [\"Calvin Klein\", \"Diesel\"],\n",
+"    size=p2_diesel_idx.sum(),\n",
+"    p=[0.75, 0.25],\n",
 ")\n",
 "\n",
 "# Apply a 20% discount to Calvin Klein products and increase the quantity by 50%\n",

docs/examples/segmentation.ipynb

Lines changed: 4 additions & 4 deletions

@@ -701,10 +701,10 @@
 "    },\n",
 "    color=\"black\",\n",
 "    bbox={\n",
-"        \"facecolor\":\"white\",\n",
-"        \"edgecolor\":\"white\",\n",
-"        \"boxstyle\":\"round,rounding_size=0.75\",\n",
-"        \"pad\":0.75,\n",
+"        \"facecolor\": \"white\",\n",
+"        \"edgecolor\": \"white\",\n",
+"        \"boxstyle\": \"round,rounding_size=0.75\",\n",
+"        \"pad\": 0.75,\n",
 "    },\n",
 "    linespacing=1.5,\n",
 ")\n",

pyretailscience/analysis/cross_shop.py

Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 """This module contains the CrossShop class that is used to create a cross-shop diagram."""

-
 import ibis
 import matplotlib.pyplot as plt
 import pandas as pd

pyretailscience/analysis/haversine.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
 - **Requires Ibis-Compatible Backend**: Ensure your Ibis backend supports trigonometric functions.
 - **Assumes Spherical Earth**: Uses the Haversine formula, which introduces slight inaccuracies due to Earth's oblate shape.
 """
+
 import ibis


pyretailscience/analysis/segmentation.py

Lines changed: 36 additions & 17 deletions

@@ -193,7 +193,7 @@ class SegTransactionStats:
     def __init__(
         self,
         data: pd.DataFrame | ibis.Table,
-        segment_col: str = "segment_name",
+        segment_col: str | list[str] = "segment_name",
         extra_aggs: dict[str, tuple[str, str]] | None = None,
     ) -> None:
         """Calculates transaction statistics by segment.
@@ -203,7 +203,8 @@ def __init__(
                 customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
                 the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
                 units_per_transaction.
-            segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
+            segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
+                Defaults to "segment_name".
             extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
                 The keys in the dictionary will be the column names for the aggregation results.
                 The values are tuples with (column_name, aggregation_function), where:
@@ -212,11 +213,14 @@ def __init__(
                 Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
         """
         cols = ColumnHelper()
+
+        if isinstance(segment_col, str):
+            segment_col = [segment_col]
         required_cols = [
             cols.customer_id,
             cols.unit_spend,
             cols.transaction_id,
-            segment_col,
+            *segment_col,
         ]
         if cols.unit_qty in data.columns:
             required_cols.append(cols.unit_qty)
@@ -274,14 +278,14 @@ def _get_col_order(include_quantity: bool) -> list[str]:
     @staticmethod
     def _calc_seg_stats(
         data: pd.DataFrame | ibis.Table,
-        segment_col: str,
+        segment_col: list[str],
         extra_aggs: dict[str, tuple[str, str]] | None = None,
     ) -> ibis.Table:
         """Calculates the transaction statistics by segment.

         Args:
             data (pd.DataFrame | ibis.Table): The transaction data.
-            segment_col (str): The column to use for the segmentation.
+            segment_col (list[str]): The columns to use for the segmentation.
             extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
                 The keys in the dictionary will be the column names for the aggregation results.
                 The values are tuples with (column_name, aggregation_function).
@@ -315,7 +319,7 @@ def _calc_seg_stats(

         # Calculate metrics for segments and total
         segment_metrics = data.group_by(segment_col).aggregate(**aggs)
-        total_metrics = data.aggregate(**aggs).mutate(segment_name=ibis.literal("Total"))
+        total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col})
         total_customers = data[cols.customer_id].nunique()

         # Cross join with total_customers to make it available for percentage calculation
@@ -344,7 +348,7 @@ def df(self) -> pd.DataFrame:
         if self._df is None:
             cols = ColumnHelper()
             col_order = [
-                self.segment_col,
+                *self.segment_col,
                 *SegTransactionStats._get_col_order(include_quantity=cols.agg_unit_qty in self.table.columns),
             ]

@@ -393,18 +397,23 @@ def plot(
         Raises:
             ValueError: If the sort_order is not "ascending", "descending" or None.
             ValueError: If the orientation is not "vertical" or "horizontal".
+            ValueError: If multiple segment columns are used, as plotting is only supported for a single segment column.
         """
         if sort_order not in ["ascending", "descending", None]:
             raise ValueError("sort_order must be either 'ascending' or 'descending' or None")
         if orientation not in ["vertical", "horizontal"]:
             raise ValueError("orientation must be either 'vertical' or 'horizontal'")
+        if len(self.segment_col) > 1:
+            raise ValueError("Plotting is only supported for a single segment column")

         default_title = f"{value_col.title()} by Segment"
         kind = "bar"
         if orientation == "horizontal":
             kind = "barh"

-        val_s = self.df.set_index(self.segment_col)[value_col]
+        # Use the first segment column for plotting
+        plot_segment_col = self.segment_col[0]
+        val_s = self.df.set_index(plot_segment_col)[value_col]
         if hide_total:
             val_s = val_s[val_s.index != "Total"]

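Taken together, the SegTransactionStats changes above let segment_col be a single column name or a list: stats are grouped by the column combination, the "Total" row writes "Total" into every segment column, and plot() now rejects multiple segment columns. A hedged usage sketch; the toy frame and the extra "region" column are illustrative, while customer_id, unit_spend and transaction_id follow the docstring above:

```python
import pandas as pd

from pyretailscience.analysis.segmentation import SegTransactionStats

# Toy transactions; "region" is an illustrative second segment column.
df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2, 3, 3],
        "transaction_id": [101, 102, 103, 104, 105],
        "unit_spend": [20.0, 35.0, 10.0, 50.0, 15.0],
        "segment_name": ["Light", "Light", "Heavy", "Heavy", "Heavy"],
        "region": ["North", "North", "South", "South", "North"],
    }
)

# A plain string still works; a list groups by the combination of columns.
stats = SegTransactionStats(df, segment_col=["segment_name", "region"])
print(stats.df)

# stats.plot(...) would raise ValueError here, since plotting is only
# supported for a single segment column.
```
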
@@ -462,7 +471,7 @@ class RFMSegmentation:

     _df: pd.DataFrame | None = None

-    def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = None) -> None:
+    def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | datetime.date | None = None) -> None:
         """Initializes the RFM segmentation process.

         Args:
@@ -472,8 +481,8 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
                 - transaction_date
                 - unit_spend
                 - transaction_id
-            current_date (Optional[str]): The reference date for calculating recency (format: "YYYY-MM-DD").
-                If not provided, the current system date will be used.
+            current_date (Optional[Union[str, datetime.date]]): The reference date for calculating recency.
+                Can be a string (format: "YYYY-MM-DD"), a date object, or None (defaults to the current system date).

         Raises:
             ValueError: If the dataframe is missing required columns.
@@ -491,9 +500,13 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
         if missing_cols:
             error_message = f"Missing required columns: {missing_cols}"
             raise ValueError(error_message)
-        current_date = (
-            datetime.date.fromisoformat(current_date) if current_date else datetime.datetime.now(datetime.UTC).date()
-        )
+
+        if isinstance(current_date, str):
+            current_date = datetime.date.fromisoformat(current_date)
+        elif current_date is None:
+            current_date = datetime.datetime.now(datetime.UTC).date()
+        elif not isinstance(current_date, datetime.date):
+            raise TypeError("current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None")

         self.table = self._compute_rfm(df, current_date)

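With the reworked type handling, the reference date may be an ISO-format string, a datetime.date, or None; any other type now fails fast with a TypeError instead of erroring later. A brief sketch (the toy frame is illustrative; the required columns follow the docstring above):

```python
import datetime

import pandas as pd

from pyretailscience.analysis.segmentation import RFMSegmentation

# Illustrative transactions with the required columns.
df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2],
        "transaction_id": [101, 102, 103],
        "transaction_date": pd.to_datetime(["2025-01-05", "2025-02-10", "2025-02-20"]),
        "unit_spend": [50.0, 75.0, 20.0],
    }
)

RFMSegmentation(df=df, current_date="2025-03-01")               # ISO string
RFMSegmentation(df=df, current_date=datetime.date(2025, 3, 1))  # date object
RFMSegmentation(df=df)                                          # None -> current UTC date
# RFMSegmentation(df=df, current_date=20250301)                 # would raise TypeError
```
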
@@ -537,13 +550,19 @@ def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Tabl
             m_score=(ibis.ntile(10).over(window_monetary)),
         )

-        rfm_segment = (rfm_scores.r_score * 100 + rfm_scores.f_score * 10 + rfm_scores.m_score).name("rfm_segment")
-
-        return rfm_scores.mutate(rfm_segment=rfm_segment)
+        return rfm_scores.mutate(
+            rfm_segment=(rfm_scores.r_score * 100 + rfm_scores.f_score * 10 + rfm_scores.m_score),
+            fm_segment=(rfm_scores.f_score * 10 + rfm_scores.m_score),
+        )

     @property
     def df(self) -> pd.DataFrame:
         """Returns the dataframe with the segment names."""
         if self._df is None:
             self._df = self.table.execute().set_index(get_option("column.customer_id"))
         return self._df
+
+    @property
+    def ibis_table(self) -> ibis.Table:
+        """Returns the computed Ibis table with RFM segmentation."""
+        return self.table
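
The new ibis_table property exposes the lazy Ibis expression alongside the materialised .df, so downstream filtering, for example on the new fm_segment column, can stay in the backend. A hedged sketch reusing the illustrative frame from the previous example:

```python
import pandas as pd

from pyretailscience.analysis.segmentation import RFMSegmentation

# Same illustrative frame as the previous sketch.
df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2],
        "transaction_id": [101, 102, 103],
        "transaction_date": pd.to_datetime(["2025-01-05", "2025-02-10", "2025-02-20"]),
        "unit_spend": [50.0, 75.0, 20.0],
    }
)

seg = RFMSegmentation(df=df, current_date="2025-03-01")

# Stay lazy in Ibis: filter on the new fm_segment column before executing.
frequent_high_spenders = seg.ibis_table.filter(seg.ibis_table.fm_segment >= 11)
print(frequent_high_spenders.execute())

# .df still materialises the full pandas result, indexed by customer_id.
print(seg.df)
```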

pyretailscience/plots/time.py

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@
 - **Helper functions**: Utilizes utility functions from the `pyretailscience` package to handle styling, formatting, and other plot adjustments.
 """

-
 import numpy as np
 import pandas as pd
 from matplotlib.axes import Axes, SubplotBase

pyretailscience/plots/venn.py

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@
 - **Pre-Aggregated Data Required**: The module does not perform data aggregation; input data should already be structured correctly.

 """
+
 from collections.abc import Callable

 import pandas as pd

tests/analysis/test_cross_shop.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 cols = ColumnHelper()


-@pytest.fixture()
+@pytest.fixture
 def sample_data():
     """Sample data for testing."""
     return pd.DataFrame(

tests/analysis/test_haversine.py

Lines changed: 2 additions & 1 deletion

@@ -1,12 +1,13 @@
 """Tests for the haversine distance module."""
+
 import ibis
 import pandas as pd
 import pytest

 from pyretailscience.analysis.haversine import haversine_distance


-@pytest.fixture()
+@pytest.fixture
 def sample_ibis_table():
     """Fixture to provide a sample Ibis table for testing."""
     data = {

tests/analysis/test_product_association.py

Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@
 class TestProductAssociations:
     """Tests for the ProductAssociations class."""

-    @pytest.fixture()
+    @pytest.fixture
     def transactions_df(self) -> pd.DataFrame:
         """Return a sample DataFrame for testing."""
         # fmt: off
@@ -23,7 +23,7 @@ def transactions_df(self) -> pd.DataFrame:
         })
         # fmt: on

-    @pytest.fixture()
+    @pytest.fixture
     def expected_results_single_items_df(self) -> pd.DataFrame:
         """Return the expected results for the single items association analysis."""
         # fmt: off
@@ -58,7 +58,7 @@ def expected_results_single_items_df(self) -> pd.DataFrame:
         )
         # fmt: on

-    @pytest.fixture()
+    @pytest.fixture
     def expected_results_pair_items_df(self) -> pd.DataFrame:
         """Return the expected results for the pair items association analysis."""
         # fmt: off

tests/analysis/test_revenue_tree.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 class TestRevenueTree:
     """Test the RevenueTree class."""

-    @pytest.fixture()
+    @pytest.fixture
     def cols(self):
         """Return a ColumnHelper instance."""
         return ColumnHelper()
