Skip to content

Commit 4505b7a

Browse files
committed
fix: refactor the code
1 parent 26459e9 commit 4505b7a

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

pyretailscience/plots/index.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,15 @@ def get_indexes(
284284

285285
if index_subgroup_col is None:
286286
overall_total = overall_agg.value.sum().execute()
287-
overall_props = overall_agg.mutate(proportion=overall_agg.value / overall_total)
287+
overall_props = overall_agg.mutate(proportion_overall=overall_agg.value / overall_total)
288288
else:
289289
overall_total = overall_agg.group_by(index_subgroup_col).aggregate(total=lambda t: t.value.sum())
290290
overall_props = (
291291
overall_agg.join(overall_total, index_subgroup_col)
292-
.mutate(proportion=lambda t: t.value / t.total)
292+
.mutate(proportion_overall=lambda t: t.value / t.total)
293293
.drop("total")
294294
)
295295

296-
overall_props = overall_props.mutate(proportion_overall=overall_props.proportion).drop("proportion")
297296
table = table.filter(table[index_col] == value_to_index)
298297
subset_agg = table.group_by(group_cols).aggregate(value=agg_fn(table[value_col]))
299298

tests/plots/test_index.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the index plot module."""
22

3+
import ibis
34
import matplotlib.pyplot as plt
45
import numpy as np
56
import pandas as pd
@@ -8,7 +9,7 @@
89
from pyretailscience.plots.index import get_indexes, plot
910

1011
OFFSET_VALUE = 100
11-
OFFSET_THRESHOLD = -5
12+
OFFSET_THRESHOLD = 5
1213

1314

1415
def test_get_indexes_basic():
@@ -109,14 +110,14 @@ def test_get_indexes_with_offset():
109110
index_col="category",
110111
value_col="value",
111112
group_col="category",
112-
offset=5,
113+
offset=OFFSET_THRESHOLD,
113114
)
114115

115116
assert isinstance(result, pd.DataFrame)
116117
assert "category" in result.columns
117118
assert "index" in result.columns
118119
assert not result.empty
119-
assert all(result["index"] >= OFFSET_THRESHOLD)
120+
assert all(result["index"] >= -OFFSET_THRESHOLD)
120121

121122

122123
def test_get_indexes_single_column():
@@ -168,6 +169,23 @@ def test_get_indexes_two_columns():
168169
pd.testing.assert_frame_equal(output, expected_output)
169170

170171

172+
def test_get_indexes_with_ibis_table_input():
173+
"""Test that the get_indexes function works with an ibis Table."""
174+
df = pd.DataFrame(
175+
{
176+
"category": ["A", "B", "C"],
177+
"value": [10, 20, 30],
178+
},
179+
)
180+
table = ibis.memtable(df)
181+
182+
result = get_indexes(table, value_to_index="A", index_col="category", value_col="value", group_col="category")
183+
assert isinstance(result, pd.DataFrame)
184+
assert "category" in result.columns
185+
assert "index" in result.columns
186+
assert not result.empty
187+
188+
171189
class TestIndexPlot:
172190
"""Tests for the index_plot function."""
173191

0 commit comments

Comments
 (0)