diff --git a/pyretailscience/plots/index.py b/pyretailscience/plots/index.py index 9ee1b9d..19b74a9 100644 --- a/pyretailscience/plots/index.py +++ b/pyretailscience/plots/index.py @@ -38,6 +38,7 @@ from typing import Literal +import ibis import numpy as np import pandas as pd from matplotlib.axes import Axes, SubplotBase @@ -52,9 +53,10 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length) df: pd.DataFrame, - df_index_filter: list[bool], value_col: str, group_col: str, + index_col: str, + value_to_index: str, agg_func: str = "sum", series_col: str | None = None, title: str | None = None, @@ -93,9 +95,10 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length) Args: df (pd.DataFrame): The dataframe to plot. - df_index_filter (list[bool]): The filter to apply to the dataframe. value_col (str): The column to plot. group_col (str): The column to group the data by. + index_col (str): The column to calculate the index on (e.g., "category"). + value_to_index (str): The baseline category or value to index against (e.g., "A"). agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum". series_col (str, optional): The column to use as the series. Defaults to None. title (str, optional): The title of the plot. Defaults to None. When None the title is set to @@ -135,15 +138,15 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length) raise ValueError( "exclude_groups and include_only_groups cannot be used together.", ) - index_df = get_indexes( df=df, - df_index_filter=df_index_filter, - index_col=group_col, + index_col=index_col, + value_to_index=value_to_index, index_subgroup_col=series_col, value_col=value_col, agg_func=agg_func, offset=100, + group_col=group_col, ) if exclude_groups is not None: @@ -239,45 +242,78 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length) def get_indexes( - df: pd.DataFrame, - df_index_filter: list[bool], + df: pd.DataFrame | ibis.Table, + value_to_index: str, index_col: str, value_col: str, + group_col: str, index_subgroup_col: str | None = None, agg_func: str = "sum", offset: int = 0, ) -> pd.DataFrame: - """Calculates the index of the value_col for the subset of a dataframe defined by df_index_filter. + """Calculates the index of the value_col using Ibis for efficient computation at scale. Args: - df (pd.DataFrame): The dataframe to calculate the index on. - df_index_filter (list[bool]): The boolean index to filter the data by. - index_col (str): The column to calculate the index on. - value_col (str): The column to calculate the index on. - index_subgroup_col (str, optional): The column to subgroup the index by. Defaults to None. - agg_func (str): The aggregation function to apply to the value_col. - offset (int, optional): The offset to subtract from the index. Defaults to 0. + df (pd.DataFrame | ibis.Table): The dataframe or Ibis table to calculate the index on. Can be a pandas dataframe or an Ibis table. + value_to_index (str): The baseline category or value to index against (e.g., "A"). + index_col (str): The column to calculate the index on (e.g., "category"). + value_col (str): The column to calculate the index on (e.g., "sales"). + group_col (str): The column to group the data by (e.g., "region"). + index_subgroup_col (str, optional): The column to subgroup the index by (e.g., "store_type"). Defaults to None. + agg_func (str, optional): The aggregation function to apply to the `value_col`. Valid options are "sum", "mean", "max", "min", or "nunique". Defaults to "sum". + offset (int, optional): The offset value to subtract from the index. This allows for adjustments to the index values. Defaults to 0. Returns: - pd.Series: The index of the value_col for the subset of data defined by filter_index. + pd.DataFrame: The calculated index values with grouping columns. """ - if all(df_index_filter) or not any(df_index_filter): - raise ValueError("The df_index_filter cannot be all True or all False.") + if isinstance(df, pd.DataFrame): + df = df.copy() + table = ibis.memtable(df) + else: + table = df + + agg_func = agg_func.lower() + if agg_func not in {"sum", "mean", "max", "min", "nunique"}: + raise ValueError("Unsupported aggregation function.") + + agg_fn = lambda x: getattr(x, agg_func)() - grp_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col] + group_cols = [group_col] if index_subgroup_col is None else [index_subgroup_col, group_col] + + overall_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col])) - overall_df = df.groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col) if index_subgroup_col is None: - overall_total = overall_df[value_col].sum() + overall_total = overall_agg.value.sum().execute() + overall_props = overall_agg.mutate(proportion_overall=overall_agg.value / overall_total) else: - overall_total = overall_df.groupby(index_subgroup_col)[value_col].sum() - overall_s = overall_df[value_col] / overall_total + overall_total = overall_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum()) + overall_props = ( + overall_agg.join(overall_total, index_subgroup_col) + .mutate(proportion_overall=lambda t: t.value / t.total) + .drop("total") + ) + + table = table.filter(table[index_col] == value_to_index) + subset_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col])) - subset_df = df[df_index_filter].groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col) if index_subgroup_col is None: - subset_total = subset_df[value_col].sum() + subset_total = subset_agg.value.sum().name("total") + subset_props = subset_agg.mutate(proportion=subset_agg.value / subset_total) else: - subset_total = subset_df.groupby(index_subgroup_col)[value_col].sum() - subset_s = subset_df[value_col] / subset_total + subset_total = subset_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum()) + subset_props = ( + subset_agg.join(subset_total, index_subgroup_col) + .filter(lambda t: t.total != 0) + .mutate(proportion=lambda t: t.value / t.total) + .drop("total") + ) + + result = ( + subset_props.join(overall_props, group_cols) + .mutate( + index=lambda t: (t.proportion / t.proportion_overall * 100) - offset, + ) + .order_by(group_cols) + ) - return ((subset_s / overall_s * 100) - offset).to_frame("index").reset_index() + return result[[*group_cols, "index"]].execute() diff --git a/tests/plots/test_index.py b/tests/plots/test_index.py index d37dd77..ab01fbd 100644 --- a/tests/plots/test_index.py +++ b/tests/plots/test_index.py @@ -1,5 +1,6 @@ """Tests for the index plot module.""" +import ibis import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -8,56 +9,119 @@ from pyretailscience.plots.index import get_indexes, plot OFFSET_VALUE = 100 +OFFSET_THRESHOLD = 5 -def test_get_indexes_single_column(): - """Test that the function works with a single column index.""" +def test_get_indexes_basic(): + """Test get_indexes function with basic input to ensure it returns a valid DataFrame.""" df = pd.DataFrame( { - "group_col": ["A", "A", "B", "B", "C", "C"], - "filter_col": ["X", "Y", "X", "Y", "X", "Y"], - "value_col": [1, 2, 3, 4, 5, 6], + "category": ["A", "A", "B", "B", "C", "C"], + "value": [10, 20, 30, 40, 50, 60], }, ) - expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [77.77777778, 100, 106.0606]}) - output = get_indexes( - df=df, - index_col="group_col", - df_index_filter=df["filter_col"] == "X", - value_col="value_col", - ) - pd.testing.assert_frame_equal(output, expected_output) + result = get_indexes(df, value_to_index="A", index_col="category", value_col="value", group_col="category") + assert isinstance(result, pd.DataFrame) + assert "category" in result.columns + assert "index" in result.columns + assert not result.empty -def test_get_indexes_two_columns(): - """Test that the function works with two columns as the index.""" + +def test_get_indexes_with_subgroup(): + """Test get_indexes function when a subgroup column is provided.""" df = pd.DataFrame( { - "group_col1": ["A", "A", "B", "B", "C", "C", "A", "A", "B", "B", "C", "C"], - "group_col2": ["D", "D", "D", "D", "D", "D", "E", "E", "E", "E", "E", "E"], - "filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"], - "value_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "subgroup": ["X", "X", "X", "Y", "Y", "Y"], + "category": ["A", "A", "B", "B", "C", "C"], + "value": [10, 20, 30, 40, 50, 60], }, ) - expected_output = pd.DataFrame( + + result = get_indexes( + df, + value_to_index="A", + index_col="category", + value_col="value", + group_col="category", + index_subgroup_col="subgroup", + ) + assert isinstance(result, pd.DataFrame) + assert "category" in result.columns + assert "index" in result.columns + assert not result.empty + + +def test_get_indexes_invalid_agg_func(): + """Test get_indexes function with an invalid aggregation function.""" + df = pd.DataFrame( { - "group_col2": ["D", "D", "D", "E", "E", "E"], - "group_col1": ["A", "B", "C", "A", "B", "C"], - "index": [77.77777778, 100, 106.0606, 98.51851852, 100, 100.9661836], + "category": ["A", "B", "C"], + "value": [10, 20, 30], }, ) - output = get_indexes( - df=df, - index_col="group_col1", - index_subgroup_col="group_col2", - df_index_filter=df["filter_col"] == "X", - value_col="value_col", + + with pytest.raises(ValueError, match="Unsupported aggregation function."): + get_indexes( + df, + value_to_index="A", + index_col="category", + value_col="value", + group_col="category", + agg_func="invalid_func", + ) + + +def test_get_indexes_with_different_aggregations(): + """Test get_indexes function with various aggregation functions.""" + df = pd.DataFrame( + { + "category": ["A", "A", "B", "B", "C", "C"], + "value": [10, 20, 30, 40, 50, 60], + }, ) - pd.testing.assert_frame_equal(output, expected_output) + + for agg in ["sum", "mean", "max", "min", "nunique"]: + result = get_indexes( + df, + value_to_index="A", + index_col="category", + value_col="value", + group_col="category", + agg_func=agg, + ) + assert isinstance(result, pd.DataFrame) + assert "category" in result.columns + assert "index" in result.columns + assert not result.empty def test_get_indexes_with_offset(): - """Test that the function works with an offset parameter.""" + """Test get_indexes function with an offset value.""" + df = pd.DataFrame( + { + "category": ["A", "B", "C"], + "value": [10, 20, 30], + }, + ) + result = get_indexes( + df, + value_to_index="A", + index_col="category", + value_col="value", + group_col="category", + offset=OFFSET_THRESHOLD, + ) + + assert isinstance(result, pd.DataFrame) + assert "category" in result.columns + assert "index" in result.columns + assert not result.empty + assert all(result["index"] >= -OFFSET_THRESHOLD) + + +def test_get_indexes_single_column(): + """Test that the function works with a single column index.""" df = pd.DataFrame( { "group_col": ["A", "A", "B", "B", "C", "C"], @@ -65,67 +129,61 @@ def test_get_indexes_with_offset(): "value_col": [1, 2, 3, 4, 5, 6], }, ) - expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [-22.22222222, 0, 6.060606061]}) + expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [77.77777778, 100, 106.0606]}) output = get_indexes( df=df, - index_col="group_col", - df_index_filter=df["filter_col"] == "X", + value_to_index="X", + index_col="filter_col", value_col="value_col", - offset=OFFSET_VALUE, # Replace magic number with the constant + group_col="group_col", ) pd.testing.assert_frame_equal(output, expected_output) -def test_get_indexes_with_agg_func(): - """Test that the function works with the nunique agg_func parameter.""" +def test_get_indexes_two_columns(): + """Test that the function works with two columns as the index.""" df = pd.DataFrame( { - "group_col1": ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"], + "group_col1": ["A", "A", "B", "B", "C", "C", "A", "A", "B", "B", "C", "C"], + "group_col2": ["D", "D", "D", "D", "D", "D", "E", "E", "E", "E", "E", "E"], "filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"], - "value_col": [1, 1, 2, 2, 3, 3, 4, 4, 5, 6, 5, 8], + "value_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], }, ) expected_output = pd.DataFrame( { - "group_col1": ["A", "B", "C"], - "index": [140, 140, 46.6666667], + "group_col2": ["D", "D", "D", "E", "E", "E"], + "group_col1": ["A", "B", "C", "A", "B", "C"], + "index": [77.77777778, 100, 106.0606, 98.51851852, 100, 100.9661836], }, ) + output = get_indexes( df=df, - index_col="group_col1", - df_index_filter=df["filter_col"] == "X", + value_to_index="X", + index_col="filter_col", value_col="value_col", - agg_func="nunique", + group_col="group_col1", + index_subgroup_col="group_col2", ) pd.testing.assert_frame_equal(output, expected_output) -def test_get_indexes_index_filter_all_same(): - """Test that the function raises a ValueError when all the values in the index filter are the same.""" +def test_get_indexes_with_ibis_table_input(): + """Test that the get_indexes function works with an ibis Table.""" df = pd.DataFrame( { - "group_col": ["A", "A", "B", "B", "C", "C"], - "filter_col": ["X", "X", "X", "X", "X", "X"], - "value_col": [1, 2, 3, 4, 5, 6], + "category": ["A", "B", "C"], + "value": [10, 20, 30], }, ) - # Assert a value error will be reaised - with pytest.raises(ValueError): - get_indexes( - df=df, - df_index_filter=[True, True, True, True, True, True], - index_col="group_col", - value_col="value_col", - ) + table = ibis.memtable(df) - with pytest.raises(ValueError): - get_indexes( - df=df, - df_index_filter=[False, False, False, False, False, False], - index_col="group_col", - value_col="value_col", - ) + result = get_indexes(table, value_to_index="A", index_col="category", value_col="value", group_col="category") + assert isinstance(result, pd.DataFrame) + assert "category" in result.columns + assert "index" in result.columns + assert not result.empty class TestIndexPlot: @@ -142,115 +200,108 @@ def test_data(self): } return pd.DataFrame(data) - @pytest.fixture() - def df_index_filter(self, test_data): - """Return a boolean filter for the dataframe.""" - return [True, False, True, True, False, True, True, False, True, False] - def test_generates_index_plot_with_default_parameters( self, test_data, - df_index_filter, ): """Test that the function generates an index plot with default parameters.""" df = test_data result_ax = plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", ) assert isinstance(result_ax, plt.Axes) - assert len(result_ax.patches) > 0 # Ensures bars are plotted + assert len(result_ax.patches) > 0 assert result_ax.get_xlabel() == "Index" assert result_ax.get_ylabel() == "Category" - def test_generates_index_plot_with_custom_title(self, test_data, df_index_filter): + def test_generates_index_plot_with_custom_title(self, test_data): """Test that the function generates an index plot with a custom title.""" df = test_data custom_title = "Sales Performance by Category" result_ax = plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", title=custom_title, ) assert isinstance(result_ax, plt.Axes) assert result_ax.get_title() == custom_title - def test_generates_index_plot_with_highlight_range( - self, - test_data, - df_index_filter, - ): + def test_generates_index_plot_with_highlight_range(self, test_data): """Test that the function generates an index plot with a highlighted range.""" df = test_data result_ax = plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", highlight_range=(80, 120), ) assert isinstance(result_ax, plt.Axes) assert result_ax.get_xlim()[0] < OFFSET_VALUE < result_ax.get_xlim()[1] - def test_generates_index_plot_with_group_filter(self, test_data, df_index_filter): + def test_generates_index_plot_with_group_filter(self, test_data): """Test that the function generates an index plot with a group filter applied.""" df = test_data result_ax = plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", include_only_groups=["A", "B"], ) assert isinstance(result_ax, plt.Axes) - def test_raises_value_error_for_invalid_sort_by(self, test_data, df_index_filter): + def test_raises_value_error_for_invalid_sort_by(self, test_data): """Test that the function raises a ValueError for an invalid sort_by parameter.""" df = test_data with pytest.raises(ValueError): plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", sort_by="invalid", ) - def test_raises_value_error_for_invalid_sort_order( - self, - test_data, - df_index_filter, - ): + def test_raises_value_error_for_invalid_sort_order(self, test_data): """Test that the function raises a ValueError for an invalid sort_order parameter.""" df = test_data with pytest.raises(ValueError): plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", sort_order="invalid", ) - def test_generates_index_plot_with_source_text(self, test_data, df_index_filter): + def test_generates_index_plot_with_source_text(self, test_data): """Test that the function generates an index plot with source text.""" df = test_data source_text = "Data source: Company XYZ" result_ax = plot( df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", + index_col="category", + value_to_index="A", source_text=source_text, ) @@ -258,47 +309,19 @@ def test_generates_index_plot_with_source_text(self, test_data, df_index_filter) source_texts = [text for text in result_ax.figure.texts if text.get_text() == source_text] assert len(source_texts) == 1 - def test_generates_index_plot_with_custom_labels(self, test_data, df_index_filter): + def test_generates_index_plot_with_custom_labels(self, test_data): """Test that the function generates an index plot with custom x and y labels.""" df = test_data result_ax = plot( df, - df_index_filter=df_index_filter, - value_col="sales", - group_col="category", - x_label="Index Value", - y_label="Product Category", - ) - - assert isinstance(result_ax, plt.Axes) - assert result_ax.get_xlabel() == "Index Value" - assert result_ax.get_ylabel() == "Product Category" - - def test_generates_index_plot_with_legend(self, test_data, df_index_filter): - """Test that the function generates an index plot with a legend when series_col is provided.""" - df = test_data - result_ax = plot( - df, - df_index_filter=df_index_filter, - value_col="sales", - group_col="category", - series_col="region", - ) - - assert isinstance(result_ax, plt.Axes) - assert result_ax.get_legend() is not None - assert len(result_ax.get_legend().get_texts()) > 0 - - def test_generates_index_plot_without_legend(self, test_data, df_index_filter): - """Test that the function generates an index plot without a legend when series_col is not provided.""" - df = test_data - result_ax = plot( - df, - df_index_filter=df_index_filter, value_col="sales", group_col="category", - series_col=None, + index_col="category", + value_to_index="A", + x_label="Sales Value", + y_label="Category Group", ) assert isinstance(result_ax, plt.Axes) - assert result_ax.get_legend() is None + assert result_ax.get_xlabel() == "Sales Value" + assert result_ax.get_ylabel() == "Category Group"