Skip to content

Commit fd3847f

Browse files
authored
Merge pull request #112 from Data-Simply/feature/cross-shop
cross-shop
2 parents 2ce2eec + 3b26d38 commit fd3847f

File tree

2 files changed

+130
-104
lines changed

2 files changed

+130
-104
lines changed

pyretailscience/cross_shop.py

Lines changed: 67 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""This module contains the CrossShop class that is used to create a cross-shop diagram."""
22

3+
4+
import ibis
35
import matplotlib.pyplot as plt
46
import pandas as pd
57
from matplotlib.axes import Axes, SubplotBase
@@ -16,24 +18,27 @@ class CrossShop:
1618

1719
def __init__(
1820
self,
19-
df: pd.DataFrame,
20-
group_1_idx: list[bool] | pd.Series,
21-
group_2_idx: list[bool] | pd.Series,
22-
group_3_idx: list[bool] | pd.Series | None = None,
21+
df: pd.DataFrame | ibis.Table,
22+
group_1_col: str,
23+
group_1_val: str,
24+
group_2_col: str,
25+
group_2_val: str,
26+
group_3_col: str | None = None,
27+
group_3_val: str | None = None,
2328
labels: list[str] | None = None,
2429
value_col: str = get_option("column.unit_spend"),
2530
agg_func: str = "sum",
2631
) -> None:
2732
"""Creates a cross-shop diagram that is used to show the overlap of customers between different groups.
2833
2934
Args:
30-
df (pd.DataFrame): The dataframe with transactional data.
31-
group_1_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
32-
first group.
33-
group_2_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
34-
second group.
35-
group_3_idx (list[bool], pd.Series, optional): An optional list of bool values determining whether the
36-
row is a part of the third group. Defaults to None. If not supplied, only two groups will be considered.
35+
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
36+
group_1_col (str): The column name for the first group.
37+
group_1_val (str): The value of the first group to match.
38+
group_2_col (str): The column name for the second group.
39+
group_2_val (str): The value of the second group to match.
40+
group_3_col (str, optional): The column name for the third group. Defaults to None.
41+
group_3_val (str, optional): The value of the third group to match. Defaults to None.
3742
labels (list[str], optional): The labels for the groups. Defaults to None.
3843
value_col (str, optional): The column to aggregate. Defaults to the option column.unit_spend.
3944
agg_func (str, optional): The aggregation function. Defaults to "sum".
@@ -51,7 +56,7 @@ def __init__(
5156
msg = f"The following columns are required but missing: {missing_cols}"
5257
raise ValueError(msg)
5358

54-
self.group_count = 2 if group_3_idx is None else 3
59+
self.group_count = 2 if group_3_col is None else 3
5560

5661
if (labels is not None) and (len(labels) != self.group_count):
5762
raise ValueError("The number of labels must be equal to the number of group indexes given")
@@ -60,9 +65,12 @@ def __init__(
6065

6166
self.cross_shop_df = self._calc_cross_shop(
6267
df=df,
63-
group_1_idx=group_1_idx,
64-
group_2_idx=group_2_idx,
65-
group_3_idx=group_3_idx,
68+
group_1_col=group_1_col,
69+
group_1_val=group_1_val,
70+
group_2_col=group_2_col,
71+
group_2_val=group_2_val,
72+
group_3_col=group_3_col,
73+
group_3_val=group_3_val,
6674
value_col=value_col,
6775
agg_func=agg_func,
6876
)
@@ -73,60 +81,71 @@ def __init__(
7381

7482
@staticmethod
7583
def _calc_cross_shop(
76-
df: pd.DataFrame,
77-
group_1_idx: list[bool],
78-
group_2_idx: list[bool],
79-
group_3_idx: list[bool] | None = None,
84+
df: pd.DataFrame | ibis.Table,
85+
group_1_col: str,
86+
group_1_val: str,
87+
group_2_col: str,
88+
group_2_val: str,
89+
group_3_col: str | None = None,
90+
group_3_val: str | None = None,
8091
value_col: str = get_option("column.unit_spend"),
8192
agg_func: str = "sum",
8293
) -> pd.DataFrame:
8394
"""Calculate the cross-shop dataframe that will be used to plot the diagram.
8495
8596
Args:
86-
df (pd.DataFrame): The dataframe with transactional data.
87-
group_1_idx (list[bool]): A list of bool values determining whether the row is a part of the first group.
88-
group_2_idx (list[bool]): A list of bool values determining whether the row is a part of the second group.
89-
group_3_idx (list[bool], optional): An optional list of bool values determining whether the row is a part
90-
of the third group. Defaults to None. If not supplied, only two groups will be considered.
97+
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
98+
group_1_col (str): Column name for the first group.
99+
group_1_val (str): Value to filter for the first group.
100+
group_2_col (str): Column name for the second group.
101+
group_2_val (str): Value to filter for the second group.
102+
group_3_col (str, optional): Column name for the third group. Defaults to None.
103+
group_3_val (str, optional): Value to filter for the third group. Defaults to None.
91104
value_col (str, optional): The column to aggregate. Defaults to option column.unit_spend.
92105
agg_func (str, optional): The aggregation function. Defaults to "sum".
93106
94107
Returns:
95108
pd.DataFrame: The cross-shop dataframe.
96109
97110
Raises:
98-
ValueError: If the groups are not mutually exclusive.
111+
ValueError: If only one of group_3_col and group_3_val is provided; they must be supplied together.
99112
"""
100113
cols = ColumnHelper()
101-
if isinstance(group_1_idx, list):
102-
group_1_idx = pd.Series(group_1_idx)
103-
if isinstance(group_2_idx, list):
104-
group_2_idx = pd.Series(group_2_idx)
105-
if group_3_idx is not None and isinstance(group_3_idx, list):
106-
group_3_idx = pd.Series(group_3_idx)
107114

108-
cs_df = df[[cols.customer_id]].copy()
109-
110-
cs_df["group_1"] = group_1_idx.astype(int)
111-
cs_df["group_2"] = group_2_idx.astype(int)
112-
group_cols = ["group_1", "group_2"]
115+
if isinstance(df, pd.DataFrame):
116+
df: ibis.Table = ibis.memtable(df)
117+
if (group_3_col is None) != (group_3_val is None):
118+
raise ValueError("If group_3_col or group_3_val is populated, then the other must be as well")
113119

114-
if group_3_idx is not None:
115-
cs_df["group_3"] = group_3_idx.astype(int)
116-
group_cols += ["group_3"]
120+
# Aggregate a temporary copy of the value column: when `value_col` is the same column as `customer_id`, selecting both directly causes a duplicate-column error in `.select()`.
121+
temp_value_col = "temp_value_col"
122+
df = df.mutate(**{temp_value_col: df[value_col]})
117123

118-
if (cs_df[group_cols].sum(axis=1) > 1).any():
119-
raise ValueError("The groups must be mutually exclusive.")
124+
group_1 = (df[group_1_col] == group_1_val).cast("int32").name("group_1")
125+
group_2 = (df[group_2_col] == group_2_val).cast("int32").name("group_2")
126+
group_3 = (df[group_3_col] == group_3_val).cast("int32").name("group_3") if group_3_col else None
120127

121-
if not any(group_1_idx) or not any(group_2_idx) or (group_3_idx is not None and not any(group_3_idx)):
122-
raise ValueError("There must at least one row selected for group_1_idx, group_2_idx, and group_3_idx.")
128+
group_cols = ["group_1", "group_2"]
129+
select_cols = [df[cols.customer_id], group_1, group_2]
130+
if group_3 is not None:
131+
group_cols.append("group_3")
132+
select_cols.append(group_3)
133+
134+
cs_df = df.select([*select_cols, df[temp_value_col]]).order_by(cols.customer_id)
135+
cs_df = (
136+
cs_df.group_by(cols.customer_id)
137+
.aggregate(
138+
**{col: cs_df[col].max().name(col) for col in group_cols},
139+
**{temp_value_col: getattr(cs_df[temp_value_col], agg_func)().name(temp_value_col)},
140+
)
141+
.order_by(cols.customer_id)
142+
).execute()
123143

124-
cs_df = cs_df.groupby(cols.customer_id)[group_cols].max()
125144
cs_df["groups"] = cs_df[group_cols].apply(lambda x: tuple(x), axis=1)
126-
127-
kpi_df = df.groupby(cols.customer_id)[value_col].agg(agg_func)
128-
129-
return cs_df.merge(kpi_df, left_index=True, right_index=True)
145+
column_order = [cols.customer_id, *group_cols, "groups", temp_value_col]
146+
cs_df = cs_df[column_order]
147+
cs_df.set_index(cols.customer_id, inplace=True)
148+
return cs_df.rename(columns={temp_value_col: value_col})
130149

131150
@staticmethod
132151
def _calc_cross_shop_table(

tests/test_cross_shop.py

Lines changed: 63 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,20 @@ def sample_data():
1515
return pd.DataFrame(
1616
{
1717
cols.customer_id: [1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10],
18-
"group_1_idx": [True, False, False, False, False, True, True, False, False, True, False, True],
19-
"group_2_idx": [False, True, False, False, True, False, False, True, False, False, True, False],
20-
"group_3_idx": [False, False, True, False, False, False, False, False, True, False, False, False],
18+
"category_1_name": [
19+
"Jeans",
20+
"Shoes",
21+
"Dresses",
22+
"Hats",
23+
"Shoes",
24+
"Jeans",
25+
"Jeans",
26+
"Shoes",
27+
"Dresses",
28+
"Jeans",
29+
"Shoes",
30+
"Jeans",
31+
],
2132
cols.unit_spend: [10, 20, 30, 40, 20, 50, 10, 20, 30, 15, 40, 50],
2233
},
2334
)
@@ -27,14 +38,16 @@ def test_calc_cross_shop_two_groups(sample_data):
2738
"""Test the _calc_cross_shop method with two groups."""
2839
cross_shop_df = CrossShop._calc_cross_shop(
2940
sample_data,
30-
group_1_idx=sample_data["group_1_idx"],
31-
group_2_idx=sample_data["group_2_idx"],
41+
group_1_col="category_1_name",
42+
group_1_val="Jeans",
43+
group_2_col="category_1_name",
44+
group_2_val="Shoes",
3245
)
3346
ret_df = pd.DataFrame(
3447
{
3548
cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
36-
"group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
37-
"group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0],
49+
"group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"),
50+
"group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"),
3851
"groups": [(1, 0), (0, 1), (0, 0), (0, 0), (1, 1), (1, 0), (0, 1), (1, 0), (0, 1), (1, 0)],
3952
cols.unit_spend: [10, 20, 30, 40, 70, 10, 20, 45, 40, 50],
4053
},
@@ -47,16 +60,19 @@ def test_calc_cross_shop_three_groups(sample_data):
4760
"""Test the _calc_cross_shop method with three groups."""
4861
cross_shop_df = CrossShop._calc_cross_shop(
4962
sample_data,
50-
group_1_idx=sample_data["group_1_idx"],
51-
group_2_idx=sample_data["group_2_idx"],
52-
group_3_idx=sample_data["group_3_idx"],
63+
group_1_col="category_1_name",
64+
group_1_val="Jeans",
65+
group_2_col="category_1_name",
66+
group_2_val="Shoes",
67+
group_3_col="category_1_name",
68+
group_3_val="Dresses",
5369
)
5470
ret_df = pd.DataFrame(
5571
{
5672
cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
57-
"group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
58-
"group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0],
59-
"group_3": [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
73+
"group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"),
74+
"group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"),
75+
"group_3": pd.Series([0, 0, 1, 0, 0, 0, 0, 1, 0, 0], dtype="int32"),
6076
"groups": [
6177
(1, 0, 0),
6278
(0, 1, 0),
@@ -73,39 +89,19 @@ def test_calc_cross_shop_three_groups(sample_data):
7389
},
7490
).set_index(cols.customer_id)
7591

76-
assert cross_shop_df.equals(ret_df)
77-
78-
79-
def test_calc_cross_shop_two_groups_overlap_error(sample_data):
80-
"""Test the _calc_cross_shop method with two groups and overlapping group indices."""
81-
with pytest.raises(ValueError):
82-
CrossShop._calc_cross_shop(
83-
sample_data,
84-
# Pass the same group index for both groups
85-
group_1_idx=sample_data["group_1_idx"],
86-
group_2_idx=sample_data["group_1_idx"],
87-
)
88-
89-
90-
def test_calc_cross_shop_three_groups_overlap_error(sample_data):
91-
"""Test the _calc_cross_shop method with three groups and overlapping group indices."""
92-
with pytest.raises(ValueError):
93-
CrossShop._calc_cross_shop(
94-
sample_data,
95-
# Pass the same group index for groups 1 and 3
96-
group_1_idx=sample_data["group_1_idx"],
97-
group_2_idx=sample_data["group_2_idx"],
98-
group_3_idx=sample_data["group_1_idx"],
99-
)
92+
pd.testing.assert_frame_equal(cross_shop_df, ret_df, check_dtype=False)
10093

10194

10295
def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
10396
"""Test the _calc_cross_shop method with three groups and customer_id as the value column."""
10497
cross_shop_df = CrossShop._calc_cross_shop(
10598
sample_data,
106-
group_1_idx=sample_data["group_1_idx"],
107-
group_2_idx=sample_data["group_2_idx"],
108-
group_3_idx=sample_data["group_3_idx"],
99+
group_1_col="category_1_name",
100+
group_1_val="Jeans",
101+
group_2_col="category_1_name",
102+
group_2_val="Shoes",
103+
group_3_col="category_1_name",
104+
group_3_val="Dresses",
109105
value_col=cols.customer_id,
110106
agg_func="nunique",
111107
)
@@ -131,17 +127,20 @@ def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
131127
index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
132128
)
133129
ret_df.index.name = cols.customer_id
134-
130+
ret_df = ret_df.astype({"group_1": "int32", "group_2": "int32", "group_3": "int32"})
135131
assert cross_shop_df.equals(ret_df)
136132

137133

138134
def test_calc_cross_shop_table(sample_data):
139135
"""Test the _calc_cross_shop_table method."""
140136
cross_shop_df = CrossShop._calc_cross_shop(
141137
sample_data,
142-
group_1_idx=sample_data["group_1_idx"],
143-
group_2_idx=sample_data["group_2_idx"],
144-
group_3_idx=sample_data["group_3_idx"],
138+
group_1_col="category_1_name",
139+
group_1_val="Jeans",
140+
group_2_col="category_1_name",
141+
group_2_val="Shoes",
142+
group_3_col="category_1_name",
143+
group_3_val="Dresses",
145144
value_col=cols.unit_spend,
146145
)
147146
cross_shop_table = CrossShop._calc_cross_shop_table(
@@ -174,9 +173,12 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
174173
"""Test the _calc_cross_shop_table method with customer_id as the value column."""
175174
cross_shop_df = CrossShop._calc_cross_shop(
176175
sample_data,
177-
group_1_idx=sample_data["group_1_idx"],
178-
group_2_idx=sample_data["group_2_idx"],
179-
group_3_idx=sample_data["group_3_idx"],
176+
group_1_col="category_1_name",
177+
group_1_val="Jeans",
178+
group_2_col="category_1_name",
179+
group_2_val="Shoes",
180+
group_3_col="category_1_name",
181+
group_3_val="Dresses",
180182
value_col=cols.customer_id,
181183
agg_func="nunique",
182184
)
@@ -195,19 +197,24 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
195197
assert cross_shop_table.equals(ret_df)
196198

197199

198-
def test_calc_cross_shop_all_groups_false(sample_data):
199-
"""Test the _calc_cross_shop method with all group indices set to False."""
200-
with pytest.raises(ValueError):
200+
def test_calc_cross_shop_invalid_group_3(sample_data):
201+
"""Test that _calc_cross_shop raises ValueError if only one of group_3_col or group_3_val is provided."""
202+
with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"):
201203
CrossShop._calc_cross_shop(
202204
sample_data,
203-
group_1_idx=[False] * len(sample_data),
204-
group_2_idx=[False] * len(sample_data),
205+
group_1_col="category_1_name",
206+
group_1_val="Jeans",
207+
group_2_col="category_1_name",
208+
group_2_val="Shoes",
209+
group_3_col="category_1_name",
205210
)
206211

207-
with pytest.raises(ValueError):
212+
with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"):
208213
CrossShop._calc_cross_shop(
209214
sample_data,
210-
group_1_idx=[False] * len(sample_data),
211-
group_2_idx=[False] * len(sample_data),
212-
group_3_idx=[False] * len(sample_data),
215+
group_1_col="category_1_name",
216+
group_1_val="Jeans",
217+
group_2_col="category_1_name",
218+
group_2_val="Shoes",
219+
group_3_val="T-Shirts",
213220
)

0 commit comments

Comments
 (0)