Skip to content

Commit 6eb4810

Browse files
committed
BUG: Aggregation on an Arrow array should return an Arrow-backed type.
Signed-off-by: Liang Yan <[email protected]>
1 parent 1f94a1b commit 6eb4810

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

pandas/core/dtypes/cast.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def maybe_cast_pointwise_result(
463463
"""
464464

465465
if isinstance(dtype, ExtensionDtype):
466-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
466+
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
467467
# TODO: avoid this special-casing
468468
# We have to special case categorical so as not to upcast
469469
# things like counts back to categorical
@@ -473,7 +473,17 @@ def maybe_cast_pointwise_result(
473473
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
474474
else:
475475
result = _maybe_cast_to_extension_array(cls, result)
476-
476+
elif isinstance(dtype, ArrowDtype):
477+
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
478+
if isinstance(pyarrow_type, ExtensionDtype):
479+
cls = pyarrow_type.construct_array_type()
480+
result = _maybe_cast_to_extension_array(cls, result)
481+
else:
482+
cls = dtype.construct_array_type()
483+
if same_dtype:
484+
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
485+
else:
486+
result = _maybe_cast_to_extension_array(cls, result)
477487
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
478488
result = maybe_downcast_to_dtype(result, dtype)
479489

pandas/tests/groupby/aggregate/test_aggregate.py

+22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
from functools import partial
77
import re
8+
import typing
89

910
import numpy as np
1011
import pytest
@@ -1630,4 +1631,25 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation():
16301631
)
16311632
gb = df.groupby("grps")
16321633
result = gb.agg(td=("td", "cumsum"))
1634+
1635+
1636+
@pytest.mark.skipif(
1637+
not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py"
1638+
)
1639+
def test_agg_arrow_type():
1640+
df = DataFrame.from_dict(
1641+
{
1642+
"category": ["A"] * 10 + ["B"] * 10,
1643+
"bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
1644+
}
1645+
)
1646+
df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")
1647+
result = df.groupby("category").agg(lambda x: x.sum() / x.count())
1648+
expected = DataFrame(
1649+
{
1650+
"bool_numpy": [0.5, 0.5],
1651+
"bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values,
1652+
},
1653+
index=Index(["A", "B"], name="category"),
1654+
)
16331655
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)