Skip to content

Commit 2b173c9

Browse files
committed
fix: remove df_filter and refactor the code and test cases
1 parent 17f178b commit 2b173c9

File tree

2 files changed

+84
-108
lines changed

2 files changed

+84
-108
lines changed

pyretailscience/plots/index.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@
5353

5454
def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
5555
df: pd.DataFrame,
56-
df_index_filter: list[bool],
5756
value_col: str,
5857
group_col: str,
58+
index_col: str,
59+
value_to_index: str,
5960
agg_func: str = "sum",
6061
series_col: str | None = None,
6162
title: str | None = None,
@@ -94,9 +95,10 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
9495
9596
Args:
9697
df (pd.DataFrame): The dataframe to plot.
97-
df_index_filter (list[bool]): The filter to apply to the dataframe.
9898
value_col (str): The column to plot.
9999
group_col (str): The column to group the data by.
100+
index_col (str): The column that contains value_to_index; rows where it equals value_to_index form the subset being indexed (e.g., "category").
101+
value_to_index (str): The baseline category or value to index against (e.g., "A").
100102
agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum".
101103
series_col (str, optional): The column to use as the series. Defaults to None.
102104
title (str, optional): The title of the plot. Defaults to None. When None the title is set to
@@ -136,15 +138,15 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
136138
raise ValueError(
137139
"exclude_groups and include_only_groups cannot be used together.",
138140
)
139-
140141
index_df = get_indexes(
141142
df=df,
142-
df_index_filter=df_index_filter,
143-
index_col=group_col,
143+
index_col=index_col,
144+
value_to_index=value_to_index,
144145
index_subgroup_col=series_col,
145146
value_col=value_col,
146147
agg_func=agg_func,
147148
offset=100,
149+
group_col=group_col,
148150
)
149151

150152
if exclude_groups is not None:
@@ -241,44 +243,43 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
241243

242244
def get_indexes(
243245
df: pd.DataFrame | ibis.Table,
244-
df_index_filter: list[bool],
246+
value_to_index: str,
245247
index_col: str,
246248
value_col: str,
249+
group_col: str,
247250
index_subgroup_col: str | None = None,
248251
agg_func: str = "sum",
249252
offset: int = 0,
250253
) -> pd.DataFrame:
251254
"""Calculates the index of the value_col using Ibis for efficient computation at scale.
252255
253256
Args:
254-
df (pd.DataFrame | ibis.Table): The dataframe or Ibis table to calculate the index on.
255-
df_index_filter (list[bool]): The boolean index to filter the data by.
256-
index_col (str): The column to calculate the index on.
257-
value_col (str): The column to calculate the index on.
258-
index_subgroup_col (str, optional): The column to subgroup the index by. Defaults to None.
259-
agg_func (str): The aggregation function to apply to the value_col.
260-
offset (int, optional): The offset to subtract from the index. Defaults to 0.
257+
df (pd.DataFrame | ibis.Table): The pandas dataframe or Ibis table to calculate the index on.
258+
value_to_index (str): The baseline category or value to index against (e.g., "A").
259+
index_col (str): The column that contains value_to_index; rows where it equals value_to_index form the subset being indexed (e.g., "category").
260+
value_col (str): The column whose values are aggregated with agg_func to compute the index (e.g., "sales").
261+
group_col (str): The column to group the data by (e.g., "region").
262+
index_subgroup_col (str, optional): The column to subgroup the index by (e.g., "store_type"). Defaults to None.
263+
agg_func (str, optional): The aggregation function to apply to the `value_col`. Valid options are "sum", "mean", "max", "min", or "nunique". Defaults to "sum".
264+
offset (int, optional): The offset value to subtract from the index. This allows for adjustments to the index values. Defaults to 0.
261265
262266
Returns:
263267
pd.DataFrame: The calculated index values with grouping columns.
264268
"""
265-
if all(df_index_filter) or not any(df_index_filter):
266-
raise ValueError("The df_index_filter cannot be all True or all False.")
267-
268269
if isinstance(df, pd.DataFrame):
269270
df = df.copy()
270-
df["_filter"] = df_index_filter
271+
df["_filter"] = value_to_index
271272
table = ibis.memtable(df)
272273
else:
273-
table = df.mutate(_filter=ibis.literal(df_index_filter))
274+
table = df.mutate(_filter=ibis.literal(value_to_index))
274275

275276
agg_func = agg_func.lower()
276277
if agg_func not in {"sum", "mean", "max", "min", "nunique"}:
277278
raise ValueError("Unsupported aggregation function.")
278279

279280
agg_fn = lambda x: getattr(x, agg_func)()
280281

281-
group_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col]
282+
group_cols = [group_col] if index_subgroup_col is None else [index_subgroup_col, group_col]
282283

283284
overall_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))
284285

@@ -294,8 +295,8 @@ def get_indexes(
294295
)
295296

296297
overall_props = overall_props.mutate(proportion_overall=overall_props.proportion).drop("proportion")
297-
298-
subset_agg = table.filter(table._filter).group_by(group_cols).aggregate(value=agg_fn(table[value_col]))
298+
table = table.filter(table[index_col] == value_to_index)
299+
subset_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))
299300

300301
if index_subgroup_col is None:
301302
subset_total = subset_agg.value.sum().name("total")
@@ -311,4 +312,5 @@ def get_indexes(
311312
result = subset_props.join(overall_props, group_cols).mutate(
312313
index=lambda t: (t.proportion / t.proportion_overall * 100) - offset,
313314
)
315+
314316
return result.execute()

0 commit comments

Comments
 (0)