Skip to content

cross-shop #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 67 additions & 48 deletions pyretailscience/cross_shop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""This module contains the CrossShop class that is used to create a cross-shop diagram."""


import ibis
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -16,24 +18,27 @@

def __init__(
self,
df: pd.DataFrame,
group_1_idx: list[bool] | pd.Series,
group_2_idx: list[bool] | pd.Series,
group_3_idx: list[bool] | pd.Series | None = None,
df: pd.DataFrame | ibis.Table,
group_1_col: str,
group_1_val: str,
group_2_col: str,
group_2_val: str,
group_3_col: str | None = None,
group_3_val: str | None = None,
labels: list[str] | None = None,
value_col: str = get_option("column.unit_spend"),
agg_func: str = "sum",
) -> None:
"""Creates a cross-shop diagram that is used to show the overlap of customers between different groups.

Args:
df (pd.DataFrame): The dataframe with transactional data.
group_1_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
first group.
group_2_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
second group.
group_3_idx (list[bool], pd.Series, optional): An optional list of bool values determining whether the
row is a part of the third group. Defaults to None. If not supplied, only two groups will be considered.
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
group_1_col (str): The column name for the first group.
group_1_val (str): The value of the first group to match.
group_2_col (str): The column name for the second group.
group_2_val (str): The value of the second group to match.
group_3_col (str, optional): The column name for the third group. Defaults to None.
group_3_val (str, optional): The value of the third group to match. Defaults to None.
labels (list[str], optional): The labels for the groups. Defaults to None.
value_col (str, optional): The column to aggregate. Defaults to the option column.unit_spend.
agg_func (str, optional): The aggregation function. Defaults to "sum".
Expand All @@ -51,7 +56,7 @@
msg = f"The following columns are required but missing: {missing_cols}"
raise ValueError(msg)

self.group_count = 2 if group_3_idx is None else 3
self.group_count = 2 if group_3_col is None else 3

Check warning on line 59 in pyretailscience/cross_shop.py

View check run for this annotation

Codecov / codecov/patch

pyretailscience/cross_shop.py#L59

Added line #L59 was not covered by tests

if (labels is not None) and (len(labels) != self.group_count):
raise ValueError("The number of labels must be equal to the number of group indexes given")
Expand All @@ -60,9 +65,12 @@

self.cross_shop_df = self._calc_cross_shop(
df=df,
group_1_idx=group_1_idx,
group_2_idx=group_2_idx,
group_3_idx=group_3_idx,
group_1_col=group_1_col,
group_1_val=group_1_val,
group_2_col=group_2_col,
group_2_val=group_2_val,
group_3_col=group_3_col,
group_3_val=group_3_val,
value_col=value_col,
agg_func=agg_func,
)
Expand All @@ -73,60 +81,71 @@

@staticmethod
def _calc_cross_shop(
df: pd.DataFrame,
group_1_idx: list[bool],
group_2_idx: list[bool],
group_3_idx: list[bool] | None = None,
df: pd.DataFrame | ibis.Table,
group_1_col: str,
group_1_val: str,
group_2_col: str,
group_2_val: str,
group_3_col: str | None = None,
group_3_val: str | None = None,
value_col: str = get_option("column.unit_spend"),
agg_func: str = "sum",
) -> pd.DataFrame:
"""Calculate the cross-shop dataframe that will be used to plot the diagram.

Args:
df (pd.DataFrame): The dataframe with transactional data.
group_1_idx (list[bool]): A list of bool values determining whether the row is a part of the first group.
group_2_idx (list[bool]): A list of bool values determining whether the row is a part of the second group.
group_3_idx (list[bool], optional): An optional list of bool values determining whether the row is a part
of the third group. Defaults to None. If not supplied, only two groups will be considered.
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
group_1_col (str): Column name for the first group.
group_1_val (str): Value to filter for the first group.
group_2_col (str): Column name for the second group.
group_2_val (str): Value to filter for the second group.
group_3_col (str, optional): Column name for the third group. Defaults to None.
group_3_val (str, optional): Value to filter for the third group. Defaults to None.
value_col (str, optional): The column to aggregate. Defaults to option column.unit_spend.
agg_func (str, optional): The aggregation function. Defaults to "sum".

Returns:
pd.DataFrame: The cross-shop dataframe.

Raises:
ValueError: If the groups are not mutually exclusive.
ValueError: If group_3_col or group_3_val is populated, then the other must be as well.
"""
cols = ColumnHelper()
if isinstance(group_1_idx, list):
group_1_idx = pd.Series(group_1_idx)
if isinstance(group_2_idx, list):
group_2_idx = pd.Series(group_2_idx)
if group_3_idx is not None and isinstance(group_3_idx, list):
group_3_idx = pd.Series(group_3_idx)

cs_df = df[[cols.customer_id]].copy()

cs_df["group_1"] = group_1_idx.astype(int)
cs_df["group_2"] = group_2_idx.astype(int)
group_cols = ["group_1", "group_2"]
if isinstance(df, pd.DataFrame):
df: ibis.Table = ibis.memtable(df)
if (group_3_col is None) != (group_3_val is None):
raise ValueError("If group_3_col or group_3_val is populated, then the other must be as well")

if group_3_idx is not None:
cs_df["group_3"] = group_3_idx.astype(int)
group_cols += ["group_3"]
# Using a temporary value column to avoid duplicate column errors during selection. This happens when `value_col` has the same name as `customer_id`, causing conflicts in `.select()`.
temp_value_col = "temp_value_col"
df = df.mutate(**{temp_value_col: df[value_col]})

if (cs_df[group_cols].sum(axis=1) > 1).any():
raise ValueError("The groups must be mutually exclusive.")
group_1 = (df[group_1_col] == group_1_val).cast("int32").name("group_1")
group_2 = (df[group_2_col] == group_2_val).cast("int32").name("group_2")
group_3 = (df[group_3_col] == group_3_val).cast("int32").name("group_3") if group_3_col else None

if not any(group_1_idx) or not any(group_2_idx) or (group_3_idx is not None and not any(group_3_idx)):
raise ValueError("There must at least one row selected for group_1_idx, group_2_idx, and group_3_idx.")
group_cols = ["group_1", "group_2"]
select_cols = [df[cols.customer_id], group_1, group_2]
if group_3 is not None:
group_cols.append("group_3")
select_cols.append(group_3)

cs_df = df.select([*select_cols, df[temp_value_col]]).order_by(cols.customer_id)
cs_df = (
cs_df.group_by(cols.customer_id)
.aggregate(
**{col: cs_df[col].max().name(col) for col in group_cols},
**{temp_value_col: getattr(cs_df[temp_value_col], agg_func)().name(temp_value_col)},
)
.order_by(cols.customer_id)
).execute()

cs_df = cs_df.groupby(cols.customer_id)[group_cols].max()
cs_df["groups"] = cs_df[group_cols].apply(lambda x: tuple(x), axis=1)

kpi_df = df.groupby(cols.customer_id)[value_col].agg(agg_func)

return cs_df.merge(kpi_df, left_index=True, right_index=True)
column_order = [cols.customer_id, *group_cols, "groups", temp_value_col]
cs_df = cs_df[column_order]
cs_df.set_index(cols.customer_id, inplace=True)
return cs_df.rename(columns={temp_value_col: value_col})

@staticmethod
def _calc_cross_shop_table(
Expand Down
119 changes: 63 additions & 56 deletions tests/test_cross_shop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,20 @@ def sample_data():
return pd.DataFrame(
{
cols.customer_id: [1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10],
"group_1_idx": [True, False, False, False, False, True, True, False, False, True, False, True],
"group_2_idx": [False, True, False, False, True, False, False, True, False, False, True, False],
"group_3_idx": [False, False, True, False, False, False, False, False, True, False, False, False],
"category_1_name": [
"Jeans",
"Shoes",
"Dresses",
"Hats",
"Shoes",
"Jeans",
"Jeans",
"Shoes",
"Dresses",
"Jeans",
"Shoes",
"Jeans",
],
cols.unit_spend: [10, 20, 30, 40, 20, 50, 10, 20, 30, 15, 40, 50],
},
)
Expand All @@ -27,14 +38,16 @@ def test_calc_cross_shop_two_groups(sample_data):
"""Test the _calc_cross_shop method with two groups."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
)
ret_df = pd.DataFrame(
{
cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
"group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0],
"group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"),
"group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"),
"groups": [(1, 0), (0, 1), (0, 0), (0, 0), (1, 1), (1, 0), (0, 1), (1, 0), (0, 1), (1, 0)],
cols.unit_spend: [10, 20, 30, 40, 70, 10, 20, 45, 40, 50],
},
Expand All @@ -47,16 +60,19 @@ def test_calc_cross_shop_three_groups(sample_data):
"""Test the _calc_cross_shop method with three groups."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
)
ret_df = pd.DataFrame(
{
cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
"group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0],
"group_3": [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
"group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"),
"group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"),
"group_3": pd.Series([0, 0, 1, 0, 0, 0, 0, 1, 0, 0], dtype="int32"),
"groups": [
(1, 0, 0),
(0, 1, 0),
Expand All @@ -73,39 +89,19 @@ def test_calc_cross_shop_three_groups(sample_data):
},
).set_index(cols.customer_id)

assert cross_shop_df.equals(ret_df)


def test_calc_cross_shop_two_groups_overlap_error(sample_data):
"""Test the _calc_cross_shop method with two groups and overlapping group indices."""
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
# Pass the same group index for both groups
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_1_idx"],
)


def test_calc_cross_shop_three_groups_overlap_error(sample_data):
"""Test the _calc_cross_shop method with three groups and overlapping group indices."""
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
# Pass the same group index for groups 1 and 3
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_1_idx"],
)
pd.testing.assert_frame_equal(cross_shop_df, ret_df, check_dtype=False)


def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
"""Test the _calc_cross_shop method with three groups and customer_id as the value column."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.customer_id,
agg_func="nunique",
)
Expand All @@ -131,17 +127,20 @@ def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
)
ret_df.index.name = cols.customer_id

ret_df = ret_df.astype({"group_1": "int32", "group_2": "int32", "group_3": "int32"})
assert cross_shop_df.equals(ret_df)


def test_calc_cross_shop_table(sample_data):
"""Test the _calc_cross_shop_table method."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.unit_spend,
)
cross_shop_table = CrossShop._calc_cross_shop_table(
Expand Down Expand Up @@ -174,9 +173,12 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
"""Test the _calc_cross_shop_table method with customer_id as the value column."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.customer_id,
agg_func="nunique",
)
Expand All @@ -195,19 +197,24 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
assert cross_shop_table.equals(ret_df)


def test_calc_cross_shop_all_groups_false(sample_data):
"""Test the _calc_cross_shop method with all group indices set to False."""
with pytest.raises(ValueError):
def test_calc_cross_shop_invalid_group_3(sample_data):
"""Test that _calc_cross_shop raises ValueError if only one of group_3_col or group_3_val is provided."""
with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"):
CrossShop._calc_cross_shop(
sample_data,
group_1_idx=[False] * len(sample_data),
group_2_idx=[False] * len(sample_data),
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"):
CrossShop._calc_cross_shop(
sample_data,
group_1_idx=[False] * len(sample_data),
group_2_idx=[False] * len(sample_data),
group_3_idx=[False] * len(sample_data),
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_val="T-Shirts",
)