diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index b05911a4d1eb0..eb18e1e07547b 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -444,6 +444,8 @@ Groupby/resample/rolling - Bug in :meth:`DataFrameGroupby.transform` produces incorrect result with transformation functions (:issue:`30918`) - Bug in :meth:`GroupBy.count` causes segmentation fault when grouped-by column contains NaNs (:issue:`32841`) - Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` produces inconsistent type when aggregating Boolean series (:issue:`32894`) +- Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`) +- Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`) - Bug in :meth:`DataFrame.resample` where an ``AmbiguousTimeError`` would be raised when the resulting timezone aware :class:`DatetimeIndex` had a DST transition at midnight (:issue:`25758`) Reshaping diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 223cc43d158e6..7dda6850ba4f7 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -33,6 +33,7 @@ ensure_str, is_bool, is_bool_dtype, + is_categorical_dtype, is_complex, is_complex_dtype, is_datetime64_dtype, @@ -279,12 +280,15 @@ def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: st dtype = maybe_cast_result_dtype(dtype, how) if not is_scalar(result): - if is_extension_array_dtype(dtype) and dtype.kind != "M": - # The result may be of any type, cast back to original - # type if it's compatible. - if len(result) and isinstance(result[0], dtype.type): - cls = dtype.construct_array_type() - result = maybe_cast_to_extension_array(cls, result, dtype=dtype) + if ( + is_extension_array_dtype(dtype) + and not is_categorical_dtype(dtype) + and dtype.kind != "M" + ): + # We have to special case categorical so as not to upcast + # things like counts back to categorical + cls = dtype.construct_array_type() + result = maybe_cast_to_extension_array(cls, result, dtype=dtype) elif numeric_only and is_numeric_dtype(dtype) or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index b1476f1059d84..e1bc058508bee 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -384,6 +384,32 @@ def test_first_last_tz_multi_column(method, ts, alpha): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize( + "values", + [ + pd.array([True, False], dtype="boolean"), + pd.array([1, 2], dtype="Int64"), + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + ], +) +@pytest.mark.parametrize("function", ["first", "last", "min", "max"]) +def test_first_last_extension_array_keeps_dtype(values, function): + # https://github.com/pandas-dev/pandas/issues/33071 + # https://github.com/pandas-dev/pandas/issues/32194 + df = DataFrame({"a": [1, 2], "b": values}) + grouped = df.groupby("a") + idx = Index([1, 2], name="a") + expected_series = Series(values, name="b", index=idx) + expected_frame = DataFrame({"b": values}, index=idx) + + result_series = getattr(grouped["b"], function)() + tm.assert_series_equal(result_series, expected_series) + + result_frame = grouped.agg({"b": function}) + tm.assert_frame_equal(result_frame, expected_frame) + + def test_nth_multi_index_as_expected(): # PR 9090, related to issue 8979 # test nth on MultiIndex diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py index bbb294bc109c1..f15d39e9e6456 100644 --- a/pandas/tests/resample/test_datetime_index.py +++ b/pandas/tests/resample/test_datetime_index.py @@ -122,9 +122,7 @@ def test_resample_integerarray(): result = ts.resample("3T").mean() expected = Series( - [1, 4, 7], - index=pd.date_range("1/1/2000", periods=3, freq="3T"), - dtype="float64", + [1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64", ) tm.assert_series_equal(result, expected)