Skip to content

refactor with ibis #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 64 additions & 28 deletions pyretailscience/plots/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from typing import Literal

import ibis
import numpy as np
import pandas as pd
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -52,9 +53,10 @@

def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
df: pd.DataFrame,
df_index_filter: list[bool],
value_col: str,
group_col: str,
index_col: str,
value_to_index: str,
agg_func: str = "sum",
series_col: str | None = None,
title: str | None = None,
Expand Down Expand Up @@ -93,9 +95,10 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)

Args:
df (pd.DataFrame): The dataframe to plot.
df_index_filter (list[bool]): The filter to apply to the dataframe.
value_col (str): The column to plot.
group_col (str): The column to group the data by.
index_col (str): The column to calculate the index on (e.g., "category").
value_to_index (str): The baseline category or value to index against (e.g., "A").
agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum".
series_col (str, optional): The column to use as the series. Defaults to None.
title (str, optional): The title of the plot. Defaults to None. When None the title is set to
Expand Down Expand Up @@ -135,15 +138,15 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)
raise ValueError(
"exclude_groups and include_only_groups cannot be used together.",
)

index_df = get_indexes(
df=df,
df_index_filter=df_index_filter,
index_col=group_col,
index_col=index_col,
value_to_index=value_to_index,
index_subgroup_col=series_col,
value_col=value_col,
agg_func=agg_func,
offset=100,
group_col=group_col,
)

if exclude_groups is not None:
Expand Down Expand Up @@ -239,45 +242,78 @@ def plot( # noqa: C901, PLR0913 (ignore complexity and line length)


def get_indexes(
df: pd.DataFrame,
df_index_filter: list[bool],
df: pd.DataFrame | ibis.Table,
value_to_index: str,
index_col: str,
value_col: str,
group_col: str,
index_subgroup_col: str | None = None,
agg_func: str = "sum",
offset: int = 0,
) -> pd.DataFrame:
"""Calculates the index of the value_col for the subset of a dataframe defined by df_index_filter.
"""Calculates the index of the value_col using Ibis for efficient computation at scale.

Args:
df (pd.DataFrame): The dataframe to calculate the index on.
df_index_filter (list[bool]): The boolean index to filter the data by.
index_col (str): The column to calculate the index on.
value_col (str): The column to calculate the index on.
index_subgroup_col (str, optional): The column to subgroup the index by. Defaults to None.
agg_func (str): The aggregation function to apply to the value_col.
offset (int, optional): The offset to subtract from the index. Defaults to 0.
df (pd.DataFrame | ibis.Table): The dataframe or Ibis table to calculate the index on. Can be a pandas dataframe or an Ibis table.
value_to_index (str): The baseline category or value to index against (e.g., "A").
index_col (str): The column to calculate the index on (e.g., "category").
value_col (str): The column to calculate the index on (e.g., "sales").
group_col (str): The column to group the data by (e.g., "region").
index_subgroup_col (str, optional): The column to subgroup the index by (e.g., "store_type"). Defaults to None.
agg_func (str, optional): The aggregation function to apply to the `value_col`. Valid options are "sum", "mean", "max", "min", or "nunique". Defaults to "sum".
offset (int, optional): The offset value to subtract from the index. This allows for adjustments to the index values. Defaults to 0.

Returns:
pd.Series: The index of the value_col for the subset of data defined by filter_index.
pd.DataFrame: The calculated index values with grouping columns.
"""
if all(df_index_filter) or not any(df_index_filter):
raise ValueError("The df_index_filter cannot be all True or all False.")
if isinstance(df, pd.DataFrame):
df = df.copy()
table = ibis.memtable(df)
else:
table = df

agg_func = agg_func.lower()
if agg_func not in {"sum", "mean", "max", "min", "nunique"}:
raise ValueError("Unsupported aggregation function.")

agg_fn = lambda x: getattr(x, agg_func)()

grp_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col]
group_cols = [group_col] if index_subgroup_col is None else [index_subgroup_col, group_col]

overall_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))

overall_df = df.groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
if index_subgroup_col is None:
overall_total = overall_df[value_col].sum()
overall_total = overall_agg.value.sum().execute()
overall_props = overall_agg.mutate(proportion_overall=overall_agg.value / overall_total)
else:
overall_total = overall_df.groupby(index_subgroup_col)[value_col].sum()
overall_s = overall_df[value_col] / overall_total
overall_total = overall_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum())
overall_props = (
overall_agg.join(overall_total, index_subgroup_col)
.mutate(proportion_overall=lambda t: t.value / t.total)
.drop("total")
)

table = table.filter(table[index_col] == value_to_index)
subset_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))

subset_df = df[df_index_filter].groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
if index_subgroup_col is None:
subset_total = subset_df[value_col].sum()
subset_total = subset_agg.value.sum().name("total")
subset_props = subset_agg.mutate(proportion=subset_agg.value / subset_total)
else:
subset_total = subset_df.groupby(index_subgroup_col)[value_col].sum()
subset_s = subset_df[value_col] / subset_total
subset_total = subset_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum())
subset_props = (
subset_agg.join(subset_total, index_subgroup_col)
.filter(lambda t: t.total != 0)
.mutate(proportion=lambda t: t.value / t.total)
.drop("total")
)

result = (
subset_props.join(overall_props, group_cols)
.mutate(
index=lambda t: (t.proportion / t.proportion_overall * 100) - offset,
)
.order_by(group_cols)
)

return ((subset_s / overall_s * 100) - offset).to_frame("index").reset_index()
return result[[*group_cols, "index"]].execute()
Loading