From f62f55beeeba80fc8f7119fded56a04c75e2ded0 Mon Sep 17 00:00:00 2001 From: debnathshoham Date: Mon, 27 Sep 2021 00:13:36 +0530 Subject: [PATCH 1/5] BUG: groupby mean fails for complex --- pandas/_libs/groupby.pyx | 14 ++++++++++---- pandas/core/groupby/ops.py | 2 +- pandas/tests/groupby/aggregate/test_aggregate.py | 8 ++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index b8700aa473d03..9682339b65371 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -480,6 +480,12 @@ ctypedef fused add_t: complex128_t object +ctypedef fused common_t: + float64_t + float32_t + complex64_t + complex128_t + @cython.wraparound(False) @cython.boundscheck(False) @@ -669,9 +675,9 @@ def group_var(floating[:, ::1] out, @cython.wraparound(False) @cython.boundscheck(False) -def group_mean(floating[:, ::1] out, +def group_mean(common_t[:, ::1] out, int64_t[::1] counts, - ndarray[floating, ndim=2] values, + ndarray[common_t, ndim=2] values, const intp_t[::1] labels, Py_ssize_t min_count=-1, bint is_datetimelike=False, @@ -711,8 +717,8 @@ def group_mean(floating[:, ::1] out, cdef: Py_ssize_t i, j, N, K, lab, ncounts = len(counts) - floating val, count, y, t, nan_val - floating[:, ::1] sumx, compensation + common_t val, count, y, t, nan_val + common_t[:, ::1] sumx, compensation int64_t[:, ::1] nobs Py_ssize_t len_values = len(values), len_labels = len(labels) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index a7ac2c7a1dba6..5759f4c9c9562 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -307,7 +307,7 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: elif how in ["mean", "median", "var"]: if isinstance(dtype, (BooleanDtype, _IntegerDtype)): return Float64Dtype() - elif is_float_dtype(dtype): + elif is_float_dtype(dtype) or is_complex_dtype(dtype): return dtype elif is_numeric_dtype(dtype): return np.dtype(np.float64) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 7bb850d38340f..c6e9a147f4698 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -1301,3 +1301,11 @@ def test_group_mean_datetime64_nat(input_data, expected_output): result = data.groupby([0, 0, 0]).mean() tm.assert_series_equal(result, expected) + + +def test_groupby_mean_complex(): + # GH#43701 + data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) + result = data.groupby(data.index % 2).mean() + expected = Series([8 + 18j, 10 + 22j]) + tm.assert_series_equal(result, expected) From 1457083d7f1fa57486ff1d58632521128164a319 Mon Sep 17 00:00:00 2001 From: debnathshoham Date: Mon, 27 Sep 2021 00:23:28 +0530 Subject: [PATCH 2/5] whatsnew --- doc/source/whatsnew/v1.4.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 243bcf6900d2e..78dd84f19bc94 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -483,6 +483,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrame.rolling.corr` when the :class:`DataFrame` columns was a :class:`MultiIndex` (:issue:`21157`) - Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`) - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`) +- Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`) Reshaping ^^^^^^^^^ From 33f499ff10f903ab584def325d1e3dbf7d0249d9 Mon Sep 17 00:00:00 2001 From: debnathshoham Date: Tue, 28 Sep 2021 00:42:44 +0530 Subject: [PATCH 3/5] added groupby.sum test for complex --- pandas/tests/groupby/aggregate/test_aggregate.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index c6e9a147f4698..2babdef71daf6 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -1309,3 +1309,11 @@ def test_groupby_mean_complex(): result = data.groupby(data.index % 2).mean() expected = Series([8 + 18j, 10 + 22j]) tm.assert_series_equal(result, expected) + + +def test_groupby_sum_complex(): + # GH#43701 + data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) + result = data.groupby(data.index % 2).sum() + expected = Series([40 + 90j, 50 + 110j]) + tm.assert_series_equal(result, expected) From b5d9a54827b77338070b72f1459fa282cbaa487b Mon Sep 17 00:00:00 2001 From: debnathshoham Date: Wed, 29 Sep 2021 00:30:21 +0530 Subject: [PATCH 4/5] renamed new fused dtype --- pandas/_libs/groupby.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 9682339b65371..57d2792567b51 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -480,7 +480,7 @@ ctypedef fused add_t: complex128_t object -ctypedef fused common_t: +ctypedef fused mean_t: float64_t float32_t complex64_t @@ -675,9 +675,9 @@ def group_var(floating[:, ::1] out, @cython.wraparound(False) @cython.boundscheck(False) -def group_mean(common_t[:, ::1] out, +def group_mean(mean_t[:, ::1] out, int64_t[::1] counts, - ndarray[common_t, ndim=2] values, + ndarray[mean_t, ndim=2] values, const intp_t[::1] labels, Py_ssize_t min_count=-1, bint is_datetimelike=False, @@ -717,8 +717,8 @@ def group_mean(common_t[:, ::1] out, cdef: Py_ssize_t i, j, N, K, lab, ncounts = len(counts) - common_t val, count, y, t, nan_val - common_t[:, ::1] sumx, compensation + mean_t val, count, y, t, nan_val + mean_t[:, ::1] sumx, compensation int64_t[:, ::1] nobs Py_ssize_t len_values = len(values), len_labels = len(labels) From 416dd697be40803bca0e84c794ea66e88bef9796 Mon Sep 17 00:00:00 2001 From: debnathshoham Date: Wed, 29 Sep 2021 22:50:04 +0530 Subject: [PATCH 5/5] added tests that raise for complex --- .../tests/groupby/aggregate/test_aggregate.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 2babdef71daf6..805e39671e084 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -1303,17 +1303,21 @@ def test_group_mean_datetime64_nat(input_data, expected_output): tm.assert_series_equal(result, expected) -def test_groupby_mean_complex(): +@pytest.mark.parametrize( + "func, output", [("mean", [8 + 18j, 10 + 22j]), ("sum", [40 + 90j, 50 + 110j])] +) +def test_groupby_complex(func, output): # GH#43701 data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) - result = data.groupby(data.index % 2).mean() - expected = Series([8 + 18j, 10 + 22j]) + result = data.groupby(data.index % 2).agg(func) + expected = Series(output) tm.assert_series_equal(result, expected) -def test_groupby_sum_complex(): +@pytest.mark.parametrize("func", ["min", "max", "var"]) +def test_groupby_complex_raises(func): # GH#43701 data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) - result = data.groupby(data.index % 2).sum() - expected = Series([40 + 90j, 50 + 110j]) - tm.assert_series_equal(result, expected) + msg = "No matching signature found" + with pytest.raises(TypeError, match=msg): + data.groupby(data.index % 2).agg(func)