Skip to content

Commit 4e0e497

Browse files
committed
chore: allow for upper and lower boundaries to be None for clipping
1 parent b8ad25c commit 4e0e497

File tree

2 files changed

+93
-12
lines changed

2 files changed

+93
-12
lines changed

pyretailscience/plots/histogram.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def plot(
9696
if isinstance(df, pd.Series):
9797
df = df.to_frame(name=value_col[0])
9898

99-
if range_lower is not None and range_upper is not None:
99+
if (range_lower is not None) or (range_upper is not None):
100100
df = _apply_range_clipping(
101101
df=df,
102102
value_col=value_col,
@@ -163,18 +163,18 @@ def _prepare_value_col(df: pd.DataFrame | pd.Series, value_col: str | list[str]
163163
def _apply_range_clipping(
164164
df: pd.DataFrame,
165165
value_col: list[str],
166-
range_lower: float,
167-
range_upper: float,
168-
range_method: Literal["clip", "fillna"],
166+
range_lower: float | None = None,
167+
range_upper: float | None = None,
168+
range_method: Literal["clip", "fillna"] = "fillna",
169169
) -> pd.DataFrame:
170170
"""Applies range clipping or filling based on the provided method and returns the modified dataframe.
171171
172172
Args:
173173
df (pd.DataFrame): The dataframe to apply range clipping to.
174174
value_col (list of str): The column(s) to apply clipping or filling to.
175-
range_lower (float): Lower bound for clipping or filling NA values.
176-
range_upper (float): Upper bound for clipping or filling NA values.
177-
range_method (Literal, optional): Whether to "clip" values outside the range or "fillna". Defaults to "clip".
175+
range_lower (float | None, optional): Lower bound for clipping or filling NA values. If None, no lower bound is applied. Defaults to None.
176+
range_upper (float | None, optional): Upper bound for clipping or filling NA values. If None, no upper bound is applied. Defaults to None.
177+
range_method (Literal, optional): Whether to "clip" values outside the range or "fillna". Defaults to "fillna".
178178
179179
Returns:
180180
pd.DataFrame: The modified dataframe with the clipping or filling applied.
@@ -184,12 +184,20 @@ def _apply_range_clipping(
184184
raise ValueError(error_msg)
185185

186186
if range_method == "clip":
187+
# Clip values based on the provided lower and upper bounds
187188
return df.assign(**{col: df[col].clip(lower=range_lower, upper=range_upper) for col in value_col})
188189

189-
# create a single boolean mask for all columns at once, which can be more efficient for large
190-
# DataFrames with multiple value columns.
191-
mask = ((range_lower is None) | (df[value_col] >= range_lower)) & ((range_upper is None) | (df[value_col] <= range_upper))
192-
return df.assign(**{col: df[col].where(mask[col], np.nan) for col in value_col})
190+
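# For the "fillna" method, build a boolean mask of in-range rows for each column and replace out-of-range values with NaN. NOTE(review): the mask below is created with a default RangeIndex, so it only aligns with df when df's index is also 0..n-1; building it as pd.Series(True, index=df.index) would avoid that.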
191+
def apply_mask(col: str) -> pd.Series:
192+
mask = pd.Series([True] * len(df))
193+
if range_lower is not None:
194+
mask &= df[col] >= range_lower
195+
if range_upper is not None:
196+
mask &= df[col] <= range_upper
197+
return df[col].where(mask, np.nan)
198+
199+
# Apply the mask to each column
200+
return df.assign(**{col: apply_mask(col) for col in value_col})
193201

194202

195203
def _get_num_histograms(df: pd.DataFrame, value_col: list[str], group_col: str | None) -> int:

tests/plots/test_histogram.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_plot_enforces_range_clipping(sample_dataframe):
105105

106106

107107
@pytest.mark.usefixtures("_mock_color_generators", "_mock_gu_functions")
108-
def test_plot_with_range_fillna(sample_dataframe, mocker):
108+
def test_plot_with_range_fillna(sample_dataframe):
109109
"""Test the plot function with range fillna."""
110110
_, ax = plt.subplots()
111111
range_lower = 3
@@ -130,6 +130,79 @@ def test_plot_with_range_fillna(sample_dataframe, mocker):
130130
assert all(range_lower <= val + np.finfo(np.float64).eps <= range_upper for val in clipped_values)
131131

132132

133+
@pytest.mark.usefixtures("_mock_color_generators", "_mock_gu_functions")
134+
def test_plot_with_range_lower_none(sample_dataframe):
135+
"""Test the plot function with range_lower=None (no lower bound) and a specific upper bound."""
136+
_, ax = plt.subplots()
137+
range_upper = 8 # No lower bound
138+
139+
result_ax = histogram.plot(
140+
df=sample_dataframe,
141+
value_col="value_1",
142+
ax=ax,
143+
title="Test Histogram with Upper Bound Only",
144+
range_lower=None,
145+
range_upper=range_upper,
146+
range_method="clip",
147+
)
148+
149+
# Collect the x position (left edge) of each histogram bar from the resulting Axes object
150+
x_data = result_ax.patches
151+
clipped_values = [patch.get_x() for patch in x_data]
152+
153+
# Ensure that the x values (bars' positions) respect the upper bound, but no lower bound is applied
154+
assert all(val + np.finfo(np.float64).eps <= range_upper for val in clipped_values)
155+
156+
157+
@pytest.mark.usefixtures("_mock_color_generators", "_mock_gu_functions")
158+
def test_plot_with_range_upper_none(sample_dataframe):
159+
"""Test the plot function with range_upper=None (no upper bound) and a specific lower bound."""
160+
_, ax = plt.subplots()
161+
range_lower = 3 # No upper bound
162+
163+
result_ax = histogram.plot(
164+
df=sample_dataframe,
165+
value_col="value_1",
166+
ax=ax,
167+
title="Test Histogram with Lower Bound Only",
168+
range_lower=range_lower,
169+
range_upper=None,
170+
range_method="clip",
171+
)
172+
173+
# Collect the x position (left edge) of each histogram bar from the resulting Axes object
174+
x_data = result_ax.patches
175+
clipped_values = [patch.get_x() for patch in x_data]
176+
177+
# Ensure that the x values (bars' positions) respect the lower bound, but no upper bound is applied
178+
assert all(range_lower <= val + np.finfo(np.float64).eps for val in clipped_values)
179+
180+
181+
@pytest.mark.usefixtures("_mock_color_generators", "_mock_gu_functions")
182+
def test_plot_fillna_outside_range(sample_dataframe):
183+
"""Test the fillna method, ensuring values outside the range are replaced by NaN."""
184+
_, ax = plt.subplots()
185+
range_lower = 3
186+
range_upper = 8
187+
188+
result_ax = histogram.plot(
189+
df=sample_dataframe,
190+
value_col="value_1",
191+
ax=ax,
192+
title="Test Histogram with Range Fillna",
193+
range_lower=range_lower,
194+
range_upper=range_upper,
195+
range_method="fillna",
196+
)
197+
198+
# Extract data from the resulting Axes
199+
x_data = result_ax.patches
200+
clipped_values = [patch.get_x() for patch in x_data]
201+
202+
# Ensure that values outside the range are not plotted (NaN)
203+
assert all(range_lower <= val + np.finfo(np.float64).eps <= range_upper for val in clipped_values)
204+
205+
133206
@pytest.mark.usefixtures("_mock_color_generators", "_mock_gu_functions")
134207
def test_plot_single_histogram_series(sample_series):
135208
"""Test the plot function with a pandas series."""

0 commit comments

Comments
 (0)