diff --git a/doc/source/whatsnew/v1.4.3.rst b/doc/source/whatsnew/v1.4.3.rst index 23c8ad63bf7bb..7c09eec212d69 100644 --- a/doc/source/whatsnew/v1.4.3.rst +++ b/doc/source/whatsnew/v1.4.3.rst @@ -16,6 +16,7 @@ Fixed regressions ~~~~~~~~~~~~~~~~~ - Fixed regression in :meth:`DataFrame.nsmallest` led to wrong results when ``np.nan`` in the sorting column (:issue:`46589`) - Fixed regression in :func:`read_fwf` raising ``ValueError`` when ``widths`` was specified with ``usecols`` (:issue:`46580`) +- Fixed regression in :meth:`.Groupby.transform` and :meth:`.Groupby.agg` failing with ``engine="numba"`` when the index was a :class:`MultiIndex` (:issue:`46867`) - Fixed regression is :meth:`.Styler.to_latex` and :meth:`.Styler.to_html` where ``buf`` failed in combination with ``encoding`` (:issue:`47053`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 0203d54e0de86..f7c89b6e7dc49 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1310,7 +1310,16 @@ def _numba_prep(self, data): sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False) sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() - sorted_index_data = data.index.take(sorted_index).to_numpy() + if len(self.grouper.groupings) > 1: + raise NotImplementedError( + "More than 1 grouping labels are not supported with engine='numba'" + ) + # GH 46867 + index_data = data.index + if isinstance(index_data, MultiIndex): + group_key = self.grouper.groupings[0].name + index_data = index_data.get_level_values(group_key) + sorted_index_data = index_data.take(sorted_index).to_numpy() starts, ends = lib.generate_slices(sorted_ids, ngroups) return ( diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index ba58ac27284b8..9f71c2c2fa0b6 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -211,3 +211,30 @@ def func_kwargs(values, index): ) expected = DataFrame({"value": [1.0, 1.0, 1.0]}) tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_multiindex_one_key(nogil, parallel, nopython): + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby("A").agg( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"]) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_multiindex_multi_key_not_supported(nogil, parallel, nopython): + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + with pytest.raises(NotImplementedError, match="More than 1 grouping labels"): + df.groupby(["A", "B"]).agg( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index a404e0b9304cc..1b8570dbdc21d 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -199,3 +199,30 @@ def func_kwargs(values, index): ) expected = DataFrame({"value": [1.0, 1.0, 1.0]}) tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_multiindex_one_key(nogil, parallel, nopython): + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby("A").transform( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"]) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_multiindex_multi_key_not_supported(nogil, parallel, nopython): + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + with pytest.raises(NotImplementedError, match="More than 1 grouping labels"): + df.groupby(["A", "B"]).transform( + numba_func, engine="numba", engine_kwargs=engine_kwargs + )