Skip to content

Commit 2a6d7b7

Browse files
authored
BUG: groupby.agg with numba and as_index=False (#51228)
* BUG: groupby.agg with numba and as_index=False * smaller-diff implementation * Whatsnew
1 parent 48c99f2 commit 2a6d7b7

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,7 @@ Groupby/resample/rolling
13011301
- Bug in :meth:`.DataFrameGroupBy.resample` raises ``KeyError`` when getting the result from a key list when resampling on time index (:issue:`50840`)
13021302
- Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` would raise incorrectly when grouper had ``axis=1`` for ``"ngroup"`` argument (:issue:`45986`)
13031303
- Bug in :meth:`.DataFrameGroupBy.describe` produced incorrect results when data had duplicate columns (:issue:`50806`)
1304+
- Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`)
13041305
-
13051306

13061307
Reshaping

pandas/core/groupby/generic.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
12661266
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
12671267
)
12681268
index = self.grouper.result_index
1269-
return self.obj._constructor(result, index=index, columns=data.columns)
1269+
result = self.obj._constructor(result, index=index, columns=data.columns)
1270+
if not self.as_index:
1271+
result = self._insert_inaxis_grouper(result)
1272+
result.index = default_index(len(result))
1273+
return result
12701274

12711275
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
12721276
func = maybe_mangle_lambdas(func)

pandas/tests/groupby/aggregate/test_numba.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def incorrect_function(values, index):
5151
# Filter warnings when parallel=True and the function can't be parallelized by Numba
5252
@pytest.mark.parametrize("jit", [True, False])
5353
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
54-
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
54+
@pytest.mark.parametrize("as_index", [True, False])
55+
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
5556
def func_numba(values, index):
5657
return np.mean(values) * 2.7
5758

@@ -65,7 +66,7 @@ def func_numba(values, index):
6566
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
6667
)
6768
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
68-
grouped = data.groupby(0)
69+
grouped = data.groupby(0, as_index=as_index)
6970
if pandas_obj == "Series":
7071
grouped = grouped[1]
7172

pandas/tests/groupby/transform/test_numba.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def incorrect_function(values, index):
4848
# Filter warnings when parallel=True and the function can't be parallelized by Numba
4949
@pytest.mark.parametrize("jit", [True, False])
5050
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
51-
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
51+
@pytest.mark.parametrize("as_index", [True, False])
52+
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
5253
def func(values, index):
5354
return values + 1
5455

@@ -62,7 +63,7 @@ def func(values, index):
6263
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
6364
)
6465
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
65-
grouped = data.groupby(0)
66+
grouped = data.groupby(0, as_index=as_index)
6667
if pandas_obj == "Series":
6768
grouped = grouped[1]
6869

0 commit comments

Comments
 (0)