Skip to content

Commit 783a68e

Browse files
authored
BUG: Groupby.filter with pd.NA (#51255)
1 parent cf2d8f9 commit 783a68e

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

pandas/core/groupby/generic.py

+2 -2
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def filter(self, func, dropna: bool = True, *args, **kwargs):
555555
# Interpret np.nan as False.
556556
def true_and_notna(x) -> bool:
557557
b = wrapper(x)
558-
return b and notna(b)
558+
return notna(b) and b
559559

560560
try:
561561
indices = [
@@ -1714,7 +1714,7 @@ def filter(self, func, dropna: bool = True, *args, **kwargs):
17141714

17151715
# interpret the result of the filter
17161716
if is_bool(res) or (is_scalar(res) and isna(res)):
1717-
if res and notna(res):
1717+
if notna(res) and res:
17181718
indices.append(self._get_index(name))
17191719
else:
17201720
# non scalars aren't allowed

pandas/tests/groupby/test_filters.py

+14
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ def test_filter_nan_is_false():
172172
tm.assert_series_equal(g_s.filter(f), s[[]])
173173

174174

175+
def test_filter_pdna_is_false():
176+
# in particular, don't raise in filter trying to call bool(pd.NA)
177+
df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)})
178+
ser = df["B"]
179+
g_df = df.groupby(df["B"])
180+
g_s = ser.groupby(ser)
181+
182+
func = lambda x: pd.NA
183+
res = g_df.filter(func)
184+
tm.assert_frame_equal(res, df.loc[[]])
185+
res = g_s.filter(func)
186+
tm.assert_series_equal(res, ser[[]])
187+
188+
175189
def test_filter_against_workaround():
176190
np.random.seed(0)
177191
# Series of ints

0 commit comments

Comments
 (0)