Skip to content

Commit d3f3d5d

Browse files
authored
REF: avoid group_selection_context (#51096)
1 parent 3770dda commit d3f3d5d

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

pandas/core/groupby/generic.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ def apply(self, func, *args, **kwargs) -> Series:
242242
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
243243

244244
if maybe_use_numba(engine):
245-
with self._group_selection_context():
246-
data = self._selected_obj
245+
data = self._obj_with_exclusions
247246
result = self._aggregate_with_numba(
248247
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
249248
)
@@ -1234,8 +1233,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
12341233
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
12351234

12361235
if maybe_use_numba(engine):
1237-
with self._group_selection_context():
1238-
data = self._selected_obj
1236+
data = self._obj_with_exclusions
12391237
result = self._aggregate_with_numba(
12401238
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
12411239
)

pandas/core/groupby/groupby.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ def _wrap_applied_output(
12771277
# numba
12781278

12791279
@final
1280-
def _numba_prep(self, data):
1280+
def _numba_prep(self, data: DataFrame):
12811281
ids, _, ngroups = self.grouper.group_info
12821282
sorted_index = get_group_index_sorter(ids, ngroups)
12831283
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
@@ -1337,7 +1337,9 @@ def _numba_agg_general(
13371337
return data._constructor(result, index=index, **result_kwargs)
13381338

13391339
@final
1340-
def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
1340+
def _transform_with_numba(
1341+
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1342+
):
13411343
"""
13421344
Perform groupby transform routine with the numba engine.
13431345
@@ -1363,7 +1365,9 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13631365
return result.take(np.argsort(sorted_index), axis=0)
13641366

13651367
@final
1366-
def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
1368+
def _aggregate_with_numba(
1369+
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1370+
):
13671371
"""
13681372
Perform groupby aggregation routine with the numba engine.
13691373
@@ -1529,15 +1533,13 @@ def _agg_general(
15291533
npfunc: Callable,
15301534
):
15311535

1532-
with self._group_selection_context():
1533-
# try a cython aggregation if we can
1534-
result = self._cython_agg_general(
1535-
how=alias,
1536-
alt=npfunc,
1537-
numeric_only=numeric_only,
1538-
min_count=min_count,
1539-
)
1540-
return result.__finalize__(self.obj, method="groupby")
1536+
result = self._cython_agg_general(
1537+
how=alias,
1538+
alt=npfunc,
1539+
numeric_only=numeric_only,
1540+
min_count=min_count,
1541+
)
1542+
return result.__finalize__(self.obj, method="groupby")
15411543

15421544
def _agg_py_fallback(
15431545
self, values: ArrayLike, ndim: int, alt: Callable
@@ -1647,9 +1649,7 @@ def _cython_transform(
16471649
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
16481650

16491651
if maybe_use_numba(engine):
1650-
# TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
1651-
with self._group_selection_context():
1652-
data = self._selected_obj
1652+
data = self._obj_with_exclusions
16531653
df = data if data.ndim == 2 else data.to_frame()
16541654
result = self._transform_with_numba(
16551655
df, func, *args, engine_kwargs=engine_kwargs, **kwargs

0 commit comments

Comments
 (0)