Skip to content

Commit 9501eb7

Browse files
committed
Merge branch 'main' of github.com:data-simply/pyretailscience into feature/rfm-segmentation
2 parents 4f77514 + 54b8bbf commit 9501eb7

24 files changed

+137
-54
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: "v0.2.2"
3+
rev: "v0.11.0"
44
hooks:
55
- id: ruff
66
args: ["--fix"]
77
- id: ruff-format
88
- repo: https://github.com/pre-commit/pre-commit-hooks
9-
rev: v4.5.0
9+
rev: v5.0.0
1010
hooks:
1111
- id: trailing-whitespace
1212
- id: end-of-file-fixer

docs/examples/cross_shop.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@
238238
"source": [
239239
"shoes_idx = df[\"category_1_name\"] == \"Shoes\"\n",
240240
"df.loc[shoes_idx, \"category_1_name\"] = np.random.RandomState(42).choice(\n",
241-
" [\"Shoes\", \"Jeans\"], size=shoes_idx.sum(), p=[0.5, 0.5],\n",
241+
" [\"Shoes\", \"Jeans\"],\n",
242+
" size=shoes_idx.sum(),\n",
243+
" p=[0.5, 0.5],\n",
242244
")"
243245
]
244246
},

docs/examples/gain_loss.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@
254254
"# Reasign half the rows to Calvin Klein and leave the other half as Diesel\n",
255255
"p2_diesel_idx = time_period_2 & (df[\"brand_name\"] == \"Diesel\")\n",
256256
"df.loc[p2_diesel_idx, \"brand_name\"] = np.random.RandomState(42).choice(\n",
257-
" [\"Calvin Klein\", \"Diesel\"], size=p2_diesel_idx.sum(), p=[0.75, 0.25],\n",
257+
" [\"Calvin Klein\", \"Diesel\"],\n",
258+
" size=p2_diesel_idx.sum(),\n",
259+
" p=[0.75, 0.25],\n",
258260
")\n",
259261
"\n",
260262
"# Apply a 20% discount to Calvin Klein products and increase the quantity by 50%\n",

docs/examples/segmentation.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,10 @@
701701
" },\n",
702702
" color=\"black\",\n",
703703
" bbox={\n",
704-
" \"facecolor\":\"white\",\n",
705-
" \"edgecolor\":\"white\",\n",
706-
" \"boxstyle\":\"round,rounding_size=0.75\",\n",
707-
" \"pad\":0.75,\n",
704+
" \"facecolor\": \"white\",\n",
705+
" \"edgecolor\": \"white\",\n",
706+
" \"boxstyle\": \"round,rounding_size=0.75\",\n",
707+
" \"pad\": 0.75,\n",
708708
" },\n",
709709
" linespacing=1.5,\n",
710710
")\n",

pyretailscience/analysis/cross_shop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""This module contains the CrossShop class that is used to create a cross-shop diagram."""
22

3-
43
import ibis
54
import matplotlib.pyplot as plt
65
import pandas as pd

pyretailscience/analysis/haversine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- **Requires Ibis-Compatible Backend**: Ensure your Ibis backend supports trigonometric functions.
2222
- **Assumes Spherical Earth**: Uses the Haversine formula, which introduces slight inaccuracies due to Earth's oblate shape.
2323
"""
24+
2425
import ibis
2526

2627

pyretailscience/analysis/segmentation.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ class SegTransactionStats:
193193
def __init__(
194194
self,
195195
data: pd.DataFrame | ibis.Table,
196-
segment_col: str = "segment_name",
196+
segment_col: str | list[str] = "segment_name",
197197
extra_aggs: dict[str, tuple[str, str]] | None = None,
198198
) -> None:
199199
"""Calculates transaction statistics by segment.
@@ -203,7 +203,8 @@ def __init__(
203203
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
204204
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
205205
units_per_transaction.
206-
segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
206+
segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207+
Defaults to "segment_name".
207208
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
208209
The keys in the dictionary will be the column names for the aggregation results.
209210
The values are tuples with (column_name, aggregation_function), where:
@@ -212,11 +213,14 @@ def __init__(
212213
Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
213214
"""
214215
cols = ColumnHelper()
216+
217+
if isinstance(segment_col, str):
218+
segment_col = [segment_col]
215219
required_cols = [
216220
cols.customer_id,
217221
cols.unit_spend,
218222
cols.transaction_id,
219-
segment_col,
223+
*segment_col,
220224
]
221225
if cols.unit_qty in data.columns:
222226
required_cols.append(cols.unit_qty)
@@ -274,14 +278,14 @@ def _get_col_order(include_quantity: bool) -> list[str]:
274278
@staticmethod
275279
def _calc_seg_stats(
276280
data: pd.DataFrame | ibis.Table,
277-
segment_col: str,
281+
segment_col: list[str],
278282
extra_aggs: dict[str, tuple[str, str]] | None = None,
279283
) -> ibis.Table:
280284
"""Calculates the transaction statistics by segment.
281285
282286
Args:
283287
data (pd.DataFrame | ibis.Table): The transaction data.
284-
segment_col (str): The column to use for the segmentation.
288+
segment_col (list[str]): The columns to use for the segmentation.
285289
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
286290
The keys in the dictionary will be the column names for the aggregation results.
287291
The values are tuples with (column_name, aggregation_function).
@@ -315,7 +319,7 @@ def _calc_seg_stats(
315319

316320
# Calculate metrics for segments and total
317321
segment_metrics = data.group_by(segment_col).aggregate(**aggs)
318-
total_metrics = data.aggregate(**aggs).mutate(segment_name=ibis.literal("Total"))
322+
total_metrics = data.aggregate(**aggs).mutate({col: ibis.literal("Total") for col in segment_col})
319323
total_customers = data[cols.customer_id].nunique()
320324

321325
# Cross join with total_customers to make it available for percentage calculation
@@ -344,7 +348,7 @@ def df(self) -> pd.DataFrame:
344348
if self._df is None:
345349
cols = ColumnHelper()
346350
col_order = [
347-
self.segment_col,
351+
*self.segment_col,
348352
*SegTransactionStats._get_col_order(include_quantity=cols.agg_unit_qty in self.table.columns),
349353
]
350354

@@ -393,18 +397,23 @@ def plot(
393397
Raises:
394398
ValueError: If the sort_order is not "ascending", "descending" or None.
395399
ValueError: If the orientation is not "vertical" or "horizontal".
400+
ValueError: If multiple segment columns are used, as plotting is only supported for a single segment column.
396401
"""
397402
if sort_order not in ["ascending", "descending", None]:
398403
raise ValueError("sort_order must be either 'ascending' or 'descending' or None")
399404
if orientation not in ["vertical", "horizontal"]:
400405
raise ValueError("orientation must be either 'vertical' or 'horizontal'")
406+
if len(self.segment_col) > 1:
407+
raise ValueError("Plotting is only supported for a single segment column")
401408

402409
default_title = f"{value_col.title()} by Segment"
403410
kind = "bar"
404411
if orientation == "horizontal":
405412
kind = "barh"
406413

407-
val_s = self.df.set_index(self.segment_col)[value_col]
414+
# Use the first segment column for plotting
415+
plot_segment_col = self.segment_col[0]
416+
val_s = self.df.set_index(plot_segment_col)[value_col]
408417
if hide_total:
409418
val_s = val_s[val_s.index != "Total"]
410419

pyretailscience/plots/time.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
- **Helper functions**: Utilizes utility functions from the `pyretailscience` package to handle styling, formatting, and other plot adjustments.
3434
"""
3535

36-
3736
import numpy as np
3837
import pandas as pd
3938
from matplotlib.axes import Axes, SubplotBase

pyretailscience/plots/venn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
- **Pre-Aggregated Data Required**: The module does not perform data aggregation; input data should already be structured correctly.
2323
2424
"""
25+
2526
from collections.abc import Callable
2627

2728
import pandas as pd

tests/analysis/test_cross_shop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
cols = ColumnHelper()
1010

1111

12-
@pytest.fixture()
12+
@pytest.fixture
1313
def sample_data():
1414
"""Sample data for testing."""
1515
return pd.DataFrame(

tests/analysis/test_haversine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Tests for the haversine distance module."""
2+
23
import ibis
34
import pandas as pd
45
import pytest
56

67
from pyretailscience.analysis.haversine import haversine_distance
78

89

9-
@pytest.fixture()
10+
@pytest.fixture
1011
def sample_ibis_table():
1112
"""Fixture to provide a sample Ibis table for testing."""
1213
data = {

tests/analysis/test_product_association.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class TestProductAssociations:
1313
"""Tests for the ProductAssociations class."""
1414

15-
@pytest.fixture()
15+
@pytest.fixture
1616
def transactions_df(self) -> pd.DataFrame:
1717
"""Return a sample DataFrame for testing."""
1818
# fmt: off
@@ -23,7 +23,7 @@ def transactions_df(self) -> pd.DataFrame:
2323
})
2424
# fmt: on
2525

26-
@pytest.fixture()
26+
@pytest.fixture
2727
def expected_results_single_items_df(self) -> pd.DataFrame:
2828
"""Return the expected results for the single items association analysis."""
2929
# fmt: off
@@ -58,7 +58,7 @@ def expected_results_single_items_df(self) -> pd.DataFrame:
5858
)
5959
# fmt: on
6060

61-
@pytest.fixture()
61+
@pytest.fixture
6262
def expected_results_pair_items_df(self) -> pd.DataFrame:
6363
"""Return the expected results for the pair items association analysis."""
6464
# fmt: off

tests/analysis/test_revenue_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class TestRevenueTree:
1313
"""Test the RevenueTree class."""
1414

15-
@pytest.fixture()
15+
@pytest.fixture
1616
def cols(self):
1717
"""Return a ColumnHelper instance."""
1818
return ColumnHelper()

tests/analysis/test_segmentation.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
class TestCalcSegStats:
1919
"""Tests for the _calc_seg_stats method."""
2020

21-
@pytest.fixture()
21+
@pytest.fixture
2222
def base_df(self):
2323
"""Return a base DataFrame for testing."""
2424
return pd.DataFrame(
@@ -319,6 +319,70 @@ def test_handles_empty_dataframe_with_errors(self):
319319
with pytest.raises(ValueError):
320320
SegTransactionStats(df, "segment_name")
321321

322+
def test_multiple_segment_columns(self):
323+
"""Test that the class correctly handles multiple segment columns."""
324+
df = pd.DataFrame(
325+
{
326+
cols.customer_id: [1, 1, 2, 2, 3, 3],
327+
cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
328+
cols.transaction_id: [101, 102, 103, 104, 105, 106],
329+
"segment_name": ["A", "A", "B", "B", "A", "A"],
330+
"region": ["North", "North", "South", "South", "East", "East"],
331+
},
332+
)
333+
334+
# Test with a list of segment columns
335+
seg_stats = SegTransactionStats(df, ["segment_name", "region"])
336+
337+
# Create expected DataFrame with the combinations actually produced
338+
expected_output = pd.DataFrame(
339+
{
340+
"segment_name": ["A", "A", "B", "Total"],
341+
"region": ["East", "North", "South", "Total"],
342+
cols.agg_unit_spend: [650.0, 250.0, 450.0, 1350.0],
343+
cols.agg_transaction_id: [2, 2, 2, 6],
344+
cols.agg_customer_id: [1, 1, 1, 3],
345+
cols.calc_spend_per_cust: [650.0, 250.0, 450.0, 450.0],
346+
cols.calc_spend_per_trans: [325.0, 125.0, 225.0, 225.0],
347+
cols.calc_trans_per_cust: [2.0, 2.0, 2.0, 2.0],
348+
cols.customers_pct: [1 / 3, 1 / 3, 1 / 3, 1.0],
349+
},
350+
)
351+
352+
# Sort both dataframes by the segment columns for consistent comparison
353+
result_df = seg_stats.df.sort_values(["segment_name", "region"]).reset_index(drop=True)
354+
expected_output = expected_output.sort_values(["segment_name", "region"]).reset_index(drop=True)
355+
356+
# Check that both segment columns are in the result
357+
assert "segment_name" in result_df.columns
358+
assert "region" in result_df.columns
359+
360+
# Check number of rows - the implementation only returns actual combinations that exist in data
361+
# plus the Total row, not all possible combinations
362+
assert len(result_df) == len(expected_output)
363+
364+
# Use pandas testing to compare the dataframes
365+
pd.testing.assert_frame_equal(result_df[expected_output.columns], expected_output)
366+
367+
def test_plot_with_multiple_segment_columns(self):
368+
"""Test that plotting with multiple segment columns raises a ValueError."""
369+
df = pd.DataFrame(
370+
{
371+
cols.customer_id: [1, 2, 3],
372+
cols.unit_spend: [100.0, 200.0, 300.0],
373+
cols.transaction_id: [101, 102, 103],
374+
"segment_name": ["A", "B", "A"],
375+
"region": ["North", "South", "East"],
376+
},
377+
)
378+
379+
seg_stats = SegTransactionStats(df, ["segment_name", "region"])
380+
381+
with pytest.raises(ValueError) as excinfo:
382+
seg_stats.plot("spend")
383+
384+
assert "Plotting is only supported for a single segment column" in str(excinfo.value)
385+
322386
def test_extra_aggs_functionality(self):
323387
"""Test that the extra_aggs parameter works correctly."""
324388
# Constants for expected values
@@ -370,9 +434,11 @@ def test_extra_aggs_functionality(self):
370434
# Sort by segment_name to ensure consistent order
371435
result_df_multi = seg_stats_multi.df.sort_values("segment_name").reset_index(drop=True)
372436

373-
assert result_df_multi.loc[0, "distinct_products"] == segment_a_product_count # Segment A
374-
assert result_df_multi.loc[1, "distinct_products"] == segment_b_product_count # Segment B
375-
assert result_df_multi.loc[2, "distinct_products"] == total_product_count # Total
437+
assert result_df_multi["distinct_products"].to_list() == [
438+
segment_a_product_count,
439+
segment_b_product_count,
440+
total_product_count,
441+
]
376442

377443
def test_extra_aggs_with_invalid_column(self):
378444
"""Test that an error is raised when an invalid column is specified in extra_aggs."""
@@ -410,7 +476,7 @@ def test_extra_aggs_with_invalid_function(self):
410476
class TestHMLSegmentation:
411477
"""Tests for the HMLSegmentation class."""
412478

413-
@pytest.fixture()
479+
@pytest.fixture
414480
def base_df(self):
415481
"""Return a base DataFrame for testing."""
416482
return pd.DataFrame(
@@ -489,7 +555,7 @@ def test_alternate_value_col(self, base_df):
489555
class TestRFMSegmentation:
490556
"""Tests for the RFMSegmentation class."""
491557

492-
@pytest.fixture()
558+
@pytest.fixture
493559
def base_df(self):
494560
"""Return a base DataFrame for testing."""
495561
return pd.DataFrame(

tests/plots/test_area.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Tests for the plots.area module."""
2+
23
from itertools import cycle
34

45
import numpy as np
@@ -13,7 +14,7 @@
1314
RNG = np.random.default_rng(42)
1415

1516

16-
@pytest.fixture()
17+
@pytest.fixture
1718
def sample_dataframe():
1819
"""A sample dataframe for Jeans sales data."""
1920
data = {
@@ -24,7 +25,7 @@ def sample_dataframe():
2425
return pd.DataFrame(data)
2526

2627

27-
@pytest.fixture()
28+
@pytest.fixture
2829
def _mock_color_generators(mocker):
2930
"""Mock the color generators for single and multi color maps."""
3031
single_color_gen = cycle(["#FF0000"]) # Mocked single-color generator (e.g., red)
@@ -34,7 +35,7 @@ def _mock_color_generators(mocker):
3435
mocker.patch("pyretailscience.style.tailwind.get_multi_color_cmap", return_value=multi_color_gen)
3536

3637

37-
@pytest.fixture()
38+
@pytest.fixture
3839
def _mock_gu_functions(mocker):
3940
mocker.patch("pyretailscience.style.graph_utils.standard_graph_styles", side_effect=lambda ax, **kwargs: ax)
4041
mocker.patch("pyretailscience.style.graph_utils.standard_tick_styles", side_effect=lambda ax: ax)

0 commit comments

Comments
 (0)