From 3b26d3857319a84db20a5d68163db04da1c1d3e0 Mon Sep 17 00:00:00 2001 From: MayurK Date: Wed, 26 Feb 2025 15:45:26 +0530 Subject: [PATCH] feat: refactor cross-shop code to ibis --- pyretailscience/cross_shop.py | 115 ++++++++++++++++++-------------- tests/test_cross_shop.py | 119 ++++++++++++++++++---------------- 2 files changed, 130 insertions(+), 104 deletions(-) diff --git a/pyretailscience/cross_shop.py b/pyretailscience/cross_shop.py index bc8b310..a7c9684 100644 --- a/pyretailscience/cross_shop.py +++ b/pyretailscience/cross_shop.py @@ -1,5 +1,7 @@ """This module contains the CrossShop class that is used to create a cross-shop diagram.""" + +import ibis import matplotlib.pyplot as plt import pandas as pd from matplotlib.axes import Axes, SubplotBase @@ -16,10 +18,13 @@ class CrossShop: def __init__( self, - df: pd.DataFrame, - group_1_idx: list[bool] | pd.Series, - group_2_idx: list[bool] | pd.Series, - group_3_idx: list[bool] | pd.Series | None = None, + df: pd.DataFrame | ibis.Table, + group_1_col: str, + group_1_val: str, + group_2_col: str, + group_2_val: str, + group_3_col: str | None = None, + group_3_val: str | None = None, labels: list[str] | None = None, value_col: str = get_option("column.unit_spend"), agg_func: str = "sum", @@ -27,13 +32,13 @@ def __init__( """Creates a cross-shop diagram that is used to show the overlap of customers between different groups. Args: - df (pd.DataFrame): The dataframe with transactional data. - group_1_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the - first group. - group_2_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the - second group. - group_3_idx (list[bool], pd.Series, optional): An optional list of bool values determining whether the - row is a part of the third group. Defaults to None. If not supplied, only two groups will be considered. + df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data. + group_1_col (str): The column name for the first group. + group_1_val (str): The value of the first group to match. + group_2_col (str): The column name for the second group. + group_2_val (str): The value of the second group to match. + group_3_col (str, optional): The column name for the third group. Defaults to None. + group_3_val (str, optional): The value of the third group to match. Defaults to None. labels (list[str], optional): The labels for the groups. Defaults to None. value_col (str, optional): The column to aggregate. Defaults to the option column.unit_spend. agg_func (str, optional): The aggregation function. Defaults to "sum". @@ -51,7 +56,7 @@ def __init__( msg = f"The following columns are required but missing: {missing_cols}" raise ValueError(msg) - self.group_count = 2 if group_3_idx is None else 3 + self.group_count = 2 if group_3_col is None else 3 if (labels is not None) and (len(labels) != self.group_count): raise ValueError("The number of labels must be equal to the number of group indexes given") @@ -60,9 +65,12 @@ def __init__( self.cross_shop_df = self._calc_cross_shop( df=df, - group_1_idx=group_1_idx, - group_2_idx=group_2_idx, - group_3_idx=group_3_idx, + group_1_col=group_1_col, + group_1_val=group_1_val, + group_2_col=group_2_col, + group_2_val=group_2_val, + group_3_col=group_3_col, + group_3_val=group_3_val, value_col=value_col, agg_func=agg_func, ) @@ -73,21 +81,26 @@ def __init__( @staticmethod def _calc_cross_shop( - df: pd.DataFrame, - group_1_idx: list[bool], - group_2_idx: list[bool], - group_3_idx: list[bool] | None = None, + df: pd.DataFrame | ibis.Table, + group_1_col: str, + group_1_val: str, + group_2_col: str, + group_2_val: str, + group_3_col: str | None = None, + group_3_val: str | None = None, value_col: str = get_option("column.unit_spend"), agg_func: str = "sum", ) -> pd.DataFrame: """Calculate the cross-shop dataframe that will be used to plot the diagram. Args: - df (pd.DataFrame): The dataframe with transactional data. - group_1_idx (list[bool]): A list of bool values determining whether the row is a part of the first group. - group_2_idx (list[bool]): A list of bool values determining whether the row is a part of the second group. - group_3_idx (list[bool], optional): An optional list of bool values determining whether the row is a part - of the third group. Defaults to None. If not supplied, only two groups will be considered. + df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data. + group_1_col (str): Column name for the first group. + group_1_val (str): Value to filter for the first group. + group_2_col (str): Column name for the second group. + group_2_val (str): Value to filter for the second group. + group_3_col (str, optional): Column name for the third group. Defaults to None. + group_3_val (str, optional): Value to filter for the third group. Defaults to None. value_col (str, optional): The column to aggregate. Defaults to option column.unit_spend. agg_func (str, optional): The aggregation function. Defaults to "sum". @@ -95,38 +108,44 @@ def _calc_cross_shop( pd.DataFrame: The cross-shop dataframe. Raises: - ValueError: If the groups are not mutually exclusive. + ValueError: If group_3_col or group_3_val is populated, then the other must be as well. """ cols = ColumnHelper() - if isinstance(group_1_idx, list): - group_1_idx = pd.Series(group_1_idx) - if isinstance(group_2_idx, list): - group_2_idx = pd.Series(group_2_idx) - if group_3_idx is not None and isinstance(group_3_idx, list): - group_3_idx = pd.Series(group_3_idx) - cs_df = df[[cols.customer_id]].copy() - - cs_df["group_1"] = group_1_idx.astype(int) - cs_df["group_2"] = group_2_idx.astype(int) - group_cols = ["group_1", "group_2"] + if isinstance(df, pd.DataFrame): + df: ibis.Table = ibis.memtable(df) + if (group_3_col is None) != (group_3_val is None): + raise ValueError("If group_3_col or group_3_val is populated, then the other must be as well") - if group_3_idx is not None: - cs_df["group_3"] = group_3_idx.astype(int) - group_cols += ["group_3"] + # Using a temporary value column to avoid duplicate column errors during selection. This happens when `value_col` has the same name as `customer_id`, causing conflicts in `.select()`. + temp_value_col = "temp_value_col" + df = df.mutate(**{temp_value_col: df[value_col]}) - if (cs_df[group_cols].sum(axis=1) > 1).any(): - raise ValueError("The groups must be mutually exclusive.") + group_1 = (df[group_1_col] == group_1_val).cast("int32").name("group_1") + group_2 = (df[group_2_col] == group_2_val).cast("int32").name("group_2") + group_3 = (df[group_3_col] == group_3_val).cast("int32").name("group_3") if group_3_col else None - if not any(group_1_idx) or not any(group_2_idx) or (group_3_idx is not None and not any(group_3_idx)): - raise ValueError("There must at least one row selected for group_1_idx, group_2_idx, and group_3_idx.") + group_cols = ["group_1", "group_2"] + select_cols = [df[cols.customer_id], group_1, group_2] + if group_3 is not None: + group_cols.append("group_3") + select_cols.append(group_3) + + cs_df = df.select([*select_cols, df[temp_value_col]]).order_by(cols.customer_id) + cs_df = ( + cs_df.group_by(cols.customer_id) + .aggregate( + **{col: cs_df[col].max().name(col) for col in group_cols}, + **{temp_value_col: getattr(cs_df[temp_value_col], agg_func)().name(temp_value_col)}, + ) + .order_by(cols.customer_id) + ).execute() - cs_df = cs_df.groupby(cols.customer_id)[group_cols].max() cs_df["groups"] = cs_df[group_cols].apply(lambda x: tuple(x), axis=1) - - kpi_df = df.groupby(cols.customer_id)[value_col].agg(agg_func) - - return cs_df.merge(kpi_df, left_index=True, right_index=True) + column_order = [cols.customer_id, *group_cols, "groups", temp_value_col] + cs_df = cs_df[column_order] + cs_df.set_index(cols.customer_id, inplace=True) + return cs_df.rename(columns={temp_value_col: value_col}) @staticmethod def _calc_cross_shop_table( diff --git a/tests/test_cross_shop.py b/tests/test_cross_shop.py index 3401270..00cb8b0 100644 --- a/tests/test_cross_shop.py +++ b/tests/test_cross_shop.py @@ -15,9 +15,20 @@ def sample_data(): return pd.DataFrame( { cols.customer_id: [1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10], - "group_1_idx": [True, False, False, False, False, True, True, False, False, True, False, True], - "group_2_idx": [False, True, False, False, True, False, False, True, False, False, True, False], - "group_3_idx": [False, False, True, False, False, False, False, False, True, False, False, False], + "category_1_name": [ + "Jeans", + "Shoes", + "Dresses", + "Hats", + "Shoes", + "Jeans", + "Jeans", + "Shoes", + "Dresses", + "Jeans", + "Shoes", + "Jeans", + ], cols.unit_spend: [10, 20, 30, 40, 20, 50, 10, 20, 30, 15, 40, 50], }, ) @@ -27,14 +38,16 @@ def test_calc_cross_shop_two_groups(sample_data): """Test the _calc_cross_shop method with two groups.""" cross_shop_df = CrossShop._calc_cross_shop( sample_data, - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", ) ret_df = pd.DataFrame( { cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1], - "group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0], + "group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"), + "group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"), "groups": [(1, 0), (0, 1), (0, 0), (0, 0), (1, 1), (1, 0), (0, 1), (1, 0), (0, 1), (1, 0)], cols.unit_spend: [10, 20, 30, 40, 70, 10, 20, 45, 40, 50], }, @@ -47,16 +60,19 @@ def test_calc_cross_shop_three_groups(sample_data): """Test the _calc_cross_shop method with three groups.""" cross_shop_df = CrossShop._calc_cross_shop( sample_data, - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], - group_3_idx=sample_data["group_3_idx"], + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_col="category_1_name", + group_3_val="Dresses", ) ret_df = pd.DataFrame( { cols.customer_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "group_1": [1, 0, 0, 0, 1, 1, 0, 1, 0, 1], - "group_2": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0], - "group_3": [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], + "group_1": pd.Series([1, 0, 0, 0, 1, 1, 0, 1, 0, 1], dtype="int32"), + "group_2": pd.Series([0, 1, 0, 0, 1, 0, 1, 0, 1, 0], dtype="int32"), + "group_3": pd.Series([0, 0, 1, 0, 0, 0, 0, 1, 0, 0], dtype="int32"), "groups": [ (1, 0, 0), (0, 1, 0), @@ -73,39 +89,19 @@ def test_calc_cross_shop_three_groups(sample_data): }, ).set_index(cols.customer_id) - assert cross_shop_df.equals(ret_df) - - -def test_calc_cross_shop_two_groups_overlap_error(sample_data): - """Test the _calc_cross_shop method with two groups and overlapping group indices.""" - with pytest.raises(ValueError): - CrossShop._calc_cross_shop( - sample_data, - # Pass the same group index for both groups - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_1_idx"], - ) - - -def test_calc_cross_shop_three_groups_overlap_error(sample_data): - """Test the _calc_cross_shop method with three groups and overlapping group indices.""" - with pytest.raises(ValueError): - CrossShop._calc_cross_shop( - sample_data, - # Pass the same group index for groups 1 and 3 - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], - group_3_idx=sample_data["group_1_idx"], - ) + pd.testing.assert_frame_equal(cross_shop_df, ret_df, check_dtype=False) def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data): """Test the _calc_cross_shop method with three groups and customer_id as the value column.""" cross_shop_df = CrossShop._calc_cross_shop( sample_data, - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], - group_3_idx=sample_data["group_3_idx"], + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_col="category_1_name", + group_3_val="Dresses", value_col=cols.customer_id, agg_func="nunique", ) @@ -131,7 +127,7 @@ def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data): index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ) ret_df.index.name = cols.customer_id - + ret_df = ret_df.astype({"group_1": "int32", "group_2": "int32", "group_3": "int32"}) assert cross_shop_df.equals(ret_df) @@ -139,9 +135,12 @@ def test_calc_cross_shop_table(sample_data): """Test the _calc_cross_shop_table method.""" cross_shop_df = CrossShop._calc_cross_shop( sample_data, - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], - group_3_idx=sample_data["group_3_idx"], + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_col="category_1_name", + group_3_val="Dresses", value_col=cols.unit_spend, ) cross_shop_table = CrossShop._calc_cross_shop_table( @@ -174,9 +173,12 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data): """Test the _calc_cross_shop_table method with customer_id as the value column.""" cross_shop_df = CrossShop._calc_cross_shop( sample_data, - group_1_idx=sample_data["group_1_idx"], - group_2_idx=sample_data["group_2_idx"], - group_3_idx=sample_data["group_3_idx"], + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_col="category_1_name", + group_3_val="Dresses", value_col=cols.customer_id, agg_func="nunique", ) @@ -195,19 +197,24 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data): assert cross_shop_table.equals(ret_df) -def test_calc_cross_shop_all_groups_false(sample_data): - """Test the _calc_cross_shop method with all group indices set to False.""" - with pytest.raises(ValueError): +def test_calc_cross_shop_invalid_group_3(sample_data): + """Test that _calc_cross_shop raises ValueError if only one of group_3_col or group_3_val is provided.""" + with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"): CrossShop._calc_cross_shop( sample_data, - group_1_idx=[False] * len(sample_data), - group_2_idx=[False] * len(sample_data), + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_col="category_1_name", ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="If group_3_col or group_3_val is populated, then the other must be as well"): CrossShop._calc_cross_shop( sample_data, - group_1_idx=[False] * len(sample_data), - group_2_idx=[False] * len(sample_data), - group_3_idx=[False] * len(sample_data), + group_1_col="category_1_name", + group_1_val="Jeans", + group_2_col="category_1_name", + group_2_val="Shoes", + group_3_val="T-Shirts", )