-
Notifications
You must be signed in to change notification settings - Fork 1
refactor with ibis #95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
17f178b
2b173c9
8d6e353
26459e9
4505b7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,7 @@ | |
|
|
||
| from typing import Literal | ||
|
|
||
| import ibis | ||
| import numpy as np | ||
| import pandas as pd | ||
| from matplotlib.axes import Axes, SubplotBase | ||
|
|
@@ -239,18 +240,18 @@ | |
|
|
||
|
|
||
| def get_indexes( | ||
| df: pd.DataFrame, | ||
| df: pd.DataFrame | ibis.Table, | ||
| df_index_filter: list[bool], | ||
| index_col: str, | ||
| value_col: str, | ||
| index_subgroup_col: str | None = None, | ||
| agg_func: str = "sum", | ||
| offset: int = 0, | ||
| ) -> pd.DataFrame: | ||
| """Calculates the index of the value_col for the subset of a dataframe defined by df_index_filter. | ||
| """Calculates the index of the value_col using Ibis for efficient computation at scale. | ||
| Args: | ||
| df (pd.DataFrame): The dataframe to calculate the index on. | ||
| df (pd.DataFrame | ibis.Table): The dataframe or Ibis table to calculate the index on. | ||
| df_index_filter (list[bool]): The boolean index to filter the data by. | ||
| index_col (str): The column to calculate the index on. | ||
| value_col (str): The column to calculate the index on. | ||
|
|
@@ -259,25 +260,55 @@ | |
| offset (int, optional): The offset to subtract from the index. Defaults to 0. | ||
| Returns: | ||
| pd.Series: The index of the value_col for the subset of data defined by filter_index. | ||
| pd.DataFrame: The calculated index values with grouping columns. | ||
| """ | ||
| if all(df_index_filter) or not any(df_index_filter): | ||
| raise ValueError("The df_index_filter cannot be all True or all False.") | ||
|
|
||
| grp_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col] | ||
| if isinstance(df, pd.DataFrame): | ||
| df = df.copy() | ||
| df["_filter"] = df_index_filter | ||
| table = ibis.memtable(df) | ||
| else: | ||
| table = df.mutate(_filter=ibis.literal(df_index_filter)) | ||
|
|
||
| agg_func = agg_func.lower() | ||
| if agg_func not in {"sum", "mean", "max", "min", "nunique"}: | ||
| raise ValueError("Unsupported aggregation function.") | ||
|
|
||
| agg_fn = lambda x: getattr(x, agg_func)() | ||
|
|
||
| group_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col] | ||
|
|
||
| overall_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col])) | ||
|
|
||
| overall_df = df.groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col) | ||
| if index_subgroup_col is None: | ||
| overall_total = overall_df[value_col].sum() | ||
| overall_total = overall_agg.value.sum().execute() | ||
| overall_props = overall_agg.mutate(proportion=overall_agg.value / overall_total) | ||
| else: | ||
| overall_total = overall_df.groupby(index_subgroup_col)[value_col].sum() | ||
| overall_s = overall_df[value_col] / overall_total | ||
| overall_total = overall_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum()) | ||
| overall_props = ( | ||
| overall_agg.join(overall_total, index_subgroup_col) | ||
| .mutate(proportion=lambda t: t.value / t.total) | ||
| .drop("total") | ||
| ) | ||
|
|
||
| overall_props = overall_props.mutate(proportion_overall=overall_props.proportion).drop("proportion") | ||
|
||
|
|
||
| subset_agg = table.filter(table._filter).group_by(group_cols).aggregate(value=agg_fn(table[value_col])) | ||
|
|
||
| subset_df = df[df_index_filter].groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col) | ||
| if index_subgroup_col is None: | ||
| subset_total = subset_df[value_col].sum() | ||
| subset_total = subset_agg.value.sum().name("total") | ||
| subset_props = subset_agg.mutate(proportion=subset_agg.value / subset_total) | ||
| else: | ||
| subset_total = subset_df.groupby(index_subgroup_col)[value_col].sum() | ||
| subset_s = subset_df[value_col] / subset_total | ||
| subset_total = subset_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum()) | ||
| subset_props = ( | ||
| subset_agg.join(subset_total, index_subgroup_col) | ||
| .mutate(proportion=lambda t: t.value / t.total) | ||
mvanwyk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| .drop("total") | ||
| ) | ||
|
|
||
| return ((subset_s / overall_s * 100) - offset).to_frame("index").reset_index() | ||
| result = subset_props.join(overall_props, group_cols).mutate( | ||
| index=lambda t: (t.proportion / t.proportion_overall * 100) - offset, | ||
| ) | ||
| return result.execute() | ||
mvanwyk marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,124 +8,106 @@ | |
| from pyretailscience.plots.index import get_indexes, plot | ||
|
|
||
| OFFSET_VALUE = 100 | ||
| OFFSET_THRESHOLD = -5 | ||
|
|
||
|
|
||
| def test_get_indexes_single_column(): | ||
| """Test that the function works with a single column index.""" | ||
| def test_get_indexes_basic(): | ||
| """Test get_indexes function with basic input to ensure it returns a valid DataFrame.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col": ["A", "A", "B", "B", "C", "C"], | ||
| "filter_col": ["X", "Y", "X", "Y", "X", "Y"], | ||
| "value_col": [1, 2, 3, 4, 5, 6], | ||
| "category": ["A", "A", "B", "B", "C", "C"], | ||
| "value": [10, 20, 30, 40, 50, 60], | ||
| }, | ||
| ) | ||
| expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [77.77777778, 100, 106.0606]}) | ||
| output = get_indexes( | ||
| df=df, | ||
| index_col="group_col", | ||
| df_index_filter=df["filter_col"] == "X", | ||
| value_col="value_col", | ||
| ) | ||
| pd.testing.assert_frame_equal(output, expected_output) | ||
| df_index_filter = [True, False, True, False, True, False] | ||
|
|
||
| result = get_indexes(df, df_index_filter, "category", "value") | ||
| assert isinstance(result, pd.DataFrame) | ||
| assert "category" in result.columns | ||
| assert "index" in result.columns | ||
| assert not result.empty | ||
|
|
||
|
|
||
| def test_get_indexes_two_columns(): | ||
| """Test that the function works with two columns as the index.""" | ||
| def test_get_indexes_with_subgroup(): | ||
| """Test get_indexes function when a subgroup column is provided.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col1": ["A", "A", "B", "B", "C", "C", "A", "A", "B", "B", "C", "C"], | ||
| "group_col2": ["D", "D", "D", "D", "D", "D", "E", "E", "E", "E", "E", "E"], | ||
| "filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"], | ||
| "value_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], | ||
| }, | ||
| ) | ||
| expected_output = pd.DataFrame( | ||
| { | ||
| "group_col2": ["D", "D", "D", "E", "E", "E"], | ||
| "group_col1": ["A", "B", "C", "A", "B", "C"], | ||
| "index": [77.77777778, 100, 106.0606, 98.51851852, 100, 100.9661836], | ||
| "subgroup": ["X", "X", "X", "Y", "Y", "Y"], | ||
| "category": ["A", "A", "B", "B", "C", "C"], | ||
| "value": [10, 20, 30, 40, 50, 60], | ||
| }, | ||
| ) | ||
| output = get_indexes( | ||
| df=df, | ||
| index_col="group_col1", | ||
| index_subgroup_col="group_col2", | ||
| df_index_filter=df["filter_col"] == "X", | ||
| value_col="value_col", | ||
| ) | ||
| pd.testing.assert_frame_equal(output, expected_output) | ||
| df_index_filter = [True, False, True, False, True, False] | ||
|
|
||
| result = get_indexes(df, df_index_filter, "category", "value", index_subgroup_col="subgroup") | ||
| assert isinstance(result, pd.DataFrame) | ||
| assert "category" in result.columns | ||
| assert "index" in result.columns | ||
| assert not result.empty | ||
|
|
||
| def test_get_indexes_with_offset(): | ||
| """Test that the function works with an offset parameter.""" | ||
|
|
||
| def test_get_indexes_invalid_filter(): | ||
| """Test get_indexes function with an invalid filter where all values are True.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col": ["A", "A", "B", "B", "C", "C"], | ||
| "filter_col": ["X", "Y", "X", "Y", "X", "Y"], | ||
| "value_col": [1, 2, 3, 4, 5, 6], | ||
| "category": ["A", "B", "C"], | ||
| "value": [10, 20, 30], | ||
| }, | ||
| ) | ||
| expected_output = pd.DataFrame({"group_col": ["A", "B", "C"], "index": [-22.22222222, 0, 6.060606061]}) | ||
| output = get_indexes( | ||
| df=df, | ||
| index_col="group_col", | ||
| df_index_filter=df["filter_col"] == "X", | ||
| value_col="value_col", | ||
| offset=OFFSET_VALUE, # Replace magic number with the constant | ||
| ) | ||
| pd.testing.assert_frame_equal(output, expected_output) | ||
| df_index_filter = [True, True, True] # Invalid case | ||
|
|
||
| with pytest.raises(ValueError, match="The df_index_filter cannot be all True or all False."): | ||
| get_indexes(df, df_index_filter, "category", "value") | ||
|
|
||
|
|
||
| def test_get_indexes_with_agg_func(): | ||
| """Test that the function works with the nunique agg_func parameter.""" | ||
| def test_get_indexes_invalid_agg_func(): | ||
| """Test get_indexes function with an invalid aggregation function.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col1": ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"], | ||
| "filter_col": ["X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y", "X", "Y"], | ||
| "value_col": [1, 1, 2, 2, 3, 3, 4, 4, 5, 6, 5, 8], | ||
| "category": ["A", "B", "C"], | ||
| "value": [10, 20, 30], | ||
| }, | ||
| ) | ||
| expected_output = pd.DataFrame( | ||
| df_index_filter = [True, False, True] | ||
|
|
||
| with pytest.raises(ValueError, match="Unsupported aggregation function."): | ||
| get_indexes(df, df_index_filter, "category", "value", agg_func="invalid_func") | ||
|
|
||
|
|
||
| def test_get_indexes_with_different_aggregations(): | ||
| """Test get_indexes function with various aggregation functions.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col1": ["A", "B", "C"], | ||
| "index": [140, 140, 46.6666667], | ||
| "category": ["A", "A", "B", "B", "C", "C"], | ||
| "value": [10, 20, 30, 40, 50, 60], | ||
| }, | ||
| ) | ||
| output = get_indexes( | ||
| df=df, | ||
| index_col="group_col1", | ||
| df_index_filter=df["filter_col"] == "X", | ||
| value_col="value_col", | ||
| agg_func="nunique", | ||
| ) | ||
| pd.testing.assert_frame_equal(output, expected_output) | ||
| df_index_filter = [True, False, True, False, True, False] | ||
|
|
||
| for agg in ["sum", "mean", "max", "min", "nunique"]: | ||
| result = get_indexes(df, df_index_filter, "category", "value", agg_func=agg) | ||
| assert isinstance(result, pd.DataFrame) | ||
| assert "category" in result.columns | ||
| assert "index" in result.columns | ||
| assert not result.empty | ||
|
|
||
| def test_get_indexes_index_filter_all_same(): | ||
| """Test that the function raises a ValueError when all the values in the index filter are the same.""" | ||
|
|
||
| def test_get_indexes_with_offset(): | ||
| """Test get_indexes function with an offset value.""" | ||
| df = pd.DataFrame( | ||
| { | ||
| "group_col": ["A", "A", "B", "B", "C", "C"], | ||
| "filter_col": ["X", "X", "X", "X", "X", "X"], | ||
| "value_col": [1, 2, 3, 4, 5, 6], | ||
| "category": ["A", "B", "C"], | ||
| "value": [10, 20, 30], | ||
| }, | ||
| ) | ||
| # Assert a value error will be reaised | ||
| with pytest.raises(ValueError): | ||
| get_indexes( | ||
| df=df, | ||
| df_index_filter=[True, True, True, True, True, True], | ||
| index_col="group_col", | ||
| value_col="value_col", | ||
| ) | ||
|
|
||
| with pytest.raises(ValueError): | ||
| get_indexes( | ||
| df=df, | ||
| df_index_filter=[False, False, False, False, False, False], | ||
| index_col="group_col", | ||
| value_col="value_col", | ||
| ) | ||
| df_index_filter = [True, False, True] | ||
| result = get_indexes(df, df_index_filter, "category", "value", offset=5) | ||
|
|
||
| assert isinstance(result, pd.DataFrame) | ||
| assert "category" in result.columns | ||
| assert "index" in result.columns | ||
| assert not result.empty | ||
| assert all(result["index"] >= OFFSET_THRESHOLD) | ||
|
||
|
|
||
|
|
||
| class TestIndexPlot: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.