Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions pyretailscience/plots/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from typing import Literal

import ibis
import numpy as np
import pandas as pd
from matplotlib.axes import Axes, SubplotBase
Expand Down Expand Up @@ -239,18 +240,18 @@


def get_indexes(
df: pd.DataFrame,
df: pd.DataFrame | ibis.Table,
df_index_filter: list[bool],
index_col: str,
value_col: str,
index_subgroup_col: str | None = None,
agg_func: str = "sum",
offset: int = 0,
) -> pd.DataFrame:
"""Calculates the index of the value_col for the subset of a dataframe defined by df_index_filter.
"""Calculates the index of the value_col using Ibis for efficient computation at scale.
Args:
df (pd.DataFrame): The dataframe to calculate the index on.
df (pd.DataFrame | ibis.Table): The dataframe or Ibis table to calculate the index on.
df_index_filter (list[bool]): The boolean index to filter the data by.
index_col (str): The column to calculate the index on.
value_col (str): The column whose values are aggregated to calculate the index.
Expand All @@ -259,25 +260,55 @@
offset (int, optional): The offset to subtract from the index. Defaults to 0.
Returns:
pd.Series: The index of the value_col for the subset of data defined by filter_index.
pd.DataFrame: The calculated index values with grouping columns.
"""
if all(df_index_filter) or not any(df_index_filter):
raise ValueError("The df_index_filter cannot be all True or all False.")

grp_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col]
if isinstance(df, pd.DataFrame):
df = df.copy()
df["_filter"] = df_index_filter
table = ibis.memtable(df)
else:
table = df.mutate(_filter=ibis.literal(df_index_filter))

Check warning on line 273 in pyretailscience/plots/index.py

View check run for this annotation

Codecov / codecov/patch

pyretailscience/plots/index.py#L273

Added line #L273 was not covered by tests

agg_func = agg_func.lower()
if agg_func not in {"sum", "mean", "max", "min", "nunique"}:
raise ValueError("Unsupported aggregation function.")

agg_fn = lambda x: getattr(x, agg_func)()

group_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col]

overall_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))

overall_df = df.groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
if index_subgroup_col is None:
overall_total = overall_df[value_col].sum()
overall_total = overall_agg.value.sum().execute()
overall_props = overall_agg.mutate(proportion=overall_agg.value / overall_total)
else:
overall_total = overall_df.groupby(index_subgroup_col)[value_col].sum()
overall_s = overall_df[value_col] / overall_total
overall_total = overall_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum())
overall_props = (
overall_agg.join(overall_total, index_subgroup_col)
.mutate(proportion=lambda t: t.value / t.total)
.drop("total")
)

overall_props = overall_props.mutate(proportion_overall=overall_props.proportion).drop("proportion")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just renaming proportion to proportion_overall?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just name it proportion_overall in the first place, then, and remove this line?


subset_agg = table.filter(table._filter).group_by(group_cols).aggregate(value=agg_fn(table[value_col]))

subset_df = df[df_index_filter].groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
if index_subgroup_col is None:
subset_total = subset_df[value_col].sum()
subset_total = subset_agg.value.sum().name("total")
subset_props = subset_agg.mutate(proportion=subset_agg.value / subset_total)
else:
subset_total = subset_df.groupby(index_subgroup_col)[value_col].sum()
subset_s = subset_df[value_col] / subset_total
subset_total = subset_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum())
subset_props = (
subset_agg.join(subset_total, index_subgroup_col)
.mutate(proportion=lambda t: t.value / t.total)
.drop("total")
)

return ((subset_s / overall_s * 100) - offset).to_frame("index").reset_index()
result = subset_props.join(overall_props, group_cols).mutate(
index=lambda t: (t.proportion / t.proportion_overall * 100) - offset,
)
return result.execute()
152 changes: 67 additions & 85 deletions tests/plots/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,124 +8,106 @@
from pyretailscience.plots.index import get_indexes, plot

OFFSET_VALUE = 100
OFFSET_THRESHOLD = -5


def test_get_indexes_single_column():
"""Test that the function works with a single column index."""
def test_get_indexes_basic():
"""Test get_indexes function with basic input to ensure it returns a valid DataFrame."""
df = pd.DataFrame(
{
"group_col": ["A", "A", "B", "B", "C", "C"],
"filter_col": ["X", "Y", "X", "Y", "X", "Y"],
"value_col": [1, 2, 3, 4, 5, 6],
"category": ["A", "A", "B", "B", "C", "C"],
"value": [10, 20, 30, 40, 50, 60],
},
)
expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [77.77777778, 100, 106.0606]})
output = get_indexes(
df=df,
index_col="group_col",
df_index_filter=df["filter_col"] == "X",
value_col="value_col",
)
pd.testing.assert_frame_equal(output, expected_output)
df_index_filter = [True, False, True, False, True, False]

result = get_indexes(df, df_index_filter, "category", "value")
assert isinstance(result, pd.DataFrame)
assert "category" in result.columns
assert "index" in result.columns
assert not result.empty


def test_get_indexes_two_columns():
"""Test that the function works with two columns as the index."""
def test_get_indexes_with_subgroup():
"""Test get_indexes function when a subgroup column is provided."""
df = pd.DataFrame(
{
"group_col1": ["A", "A", "B", "B", "C", "C", "A", "A", "B", "B", "C", "C"],
"group_col2": ["D", "D", "D", "D", "D", "D", "E", "E", "E", "E", "E", "E"],
"filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"],
"value_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
},
)
expected_output = pd.DataFrame(
{
"group_col2": ["D", "D", "D", "E", "E", "E"],
"group_col1": ["A", "B", "C", "A", "B", "C"],
"index": [77.77777778, 100, 106.0606, 98.51851852, 100, 100.9661836],
"subgroup": ["X", "X", "X", "Y", "Y", "Y"],
"category": ["A", "A", "B", "B", "C", "C"],
"value": [10, 20, 30, 40, 50, 60],
},
)
output = get_indexes(
df=df,
index_col="group_col1",
index_subgroup_col="group_col2",
df_index_filter=df["filter_col"] == "X",
value_col="value_col",
)
pd.testing.assert_frame_equal(output, expected_output)
df_index_filter = [True, False, True, False, True, False]

result = get_indexes(df, df_index_filter, "category", "value", index_subgroup_col="subgroup")
assert isinstance(result, pd.DataFrame)
assert "category" in result.columns
assert "index" in result.columns
assert not result.empty

def test_get_indexes_with_offset():
"""Test that the function works with an offset parameter."""

def test_get_indexes_invalid_filter():
"""Test get_indexes function with an invalid filter where all values are True."""
df = pd.DataFrame(
{
"group_col": ["A", "A", "B", "B", "C", "C"],
"filter_col": ["X", "Y", "X", "Y", "X", "Y"],
"value_col": [1, 2, 3, 4, 5, 6],
"category": ["A", "B", "C"],
"value": [10, 20, 30],
},
)
expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [-22.22222222, 0, 6.060606061]})
output = get_indexes(
df=df,
index_col="group_col",
df_index_filter=df["filter_col"] == "X",
value_col="value_col",
offset=OFFSET_VALUE, # Replace magic number with the constant
)
pd.testing.assert_frame_equal(output, expected_output)
df_index_filter = [True, True, True] # Invalid case

with pytest.raises(ValueError, match="The df_index_filter cannot be all True or all False."):
get_indexes(df, df_index_filter, "category", "value")


def test_get_indexes_with_agg_func():
"""Test that the function works with the nunique agg_func parameter."""
def test_get_indexes_invalid_agg_func():
"""Test get_indexes function with an invalid aggregation function."""
df = pd.DataFrame(
{
"group_col1": ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"],
"filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"],
"value_col": [1, 1, 2, 2, 3, 3, 4, 4, 5, 6, 5, 8],
"category": ["A", "B", "C"],
"value": [10, 20, 30],
},
)
expected_output = pd.DataFrame(
df_index_filter = [True, False, True]

with pytest.raises(ValueError, match="Unsupported aggregation function."):
get_indexes(df, df_index_filter, "category", "value", agg_func="invalid_func")


def test_get_indexes_with_different_aggregations():
"""Test get_indexes function with various aggregation functions."""
df = pd.DataFrame(
{
"group_col1": ["A", "B", "C"],
"index": [140, 140, 46.6666667],
"category": ["A", "A", "B", "B", "C", "C"],
"value": [10, 20, 30, 40, 50, 60],
},
)
output = get_indexes(
df=df,
index_col="group_col1",
df_index_filter=df["filter_col"] == "X",
value_col="value_col",
agg_func="nunique",
)
pd.testing.assert_frame_equal(output, expected_output)
df_index_filter = [True, False, True, False, True, False]

for agg in ["sum", "mean", "max", "min", "nunique"]:
result = get_indexes(df, df_index_filter, "category", "value", agg_func=agg)
assert isinstance(result, pd.DataFrame)
assert "category" in result.columns
assert "index" in result.columns
assert not result.empty

def test_get_indexes_index_filter_all_same():
"""Test that the function raises a ValueError when all the values in the index filter are the same."""

def test_get_indexes_with_offset():
"""Test get_indexes function with an offset value."""
df = pd.DataFrame(
{
"group_col": ["A", "A", "B", "B", "C", "C"],
"filter_col": ["X", "X", "X", "X", "X", "X"],
"value_col": [1, 2, 3, 4, 5, 6],
"category": ["A", "B", "C"],
"value": [10, 20, 30],
},
)
# Assert a value error will be raised
with pytest.raises(ValueError):
get_indexes(
df=df,
df_index_filter=[True, True, True, True, True, True],
index_col="group_col",
value_col="value_col",
)

with pytest.raises(ValueError):
get_indexes(
df=df,
df_index_filter=[False, False, False, False, False, False],
index_col="group_col",
value_col="value_col",
)
df_index_filter = [True, False, True]
result = get_indexes(df, df_index_filter, "category", "value", offset=5)

assert isinstance(result, pd.DataFrame)
assert "category" in result.columns
assert "index" in result.columns
assert not result.empty
assert all(result["index"] >= OFFSET_THRESHOLD)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since OFFSET_THRESHOLD is only used in this function, can you move its definition here, please?



class TestIndexPlot:
Expand Down