Skip to content

Commit 1a57596

Browse files
authored
[BUG] Sum of grouped bool has inconsistent dtype (#32894)
1 parent 883379c commit 1a57596

File tree

9 files changed

+152
-80
lines changed

9 files changed

+152
-80
lines changed

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ Groupby/resample/rolling
400400

401401
- Bug in :meth:`GroupBy.apply` raises ``ValueError`` when the ``by`` axis is not sorted and has duplicates and the applied ``func`` does not mutate passed in objects (:issue:`30667`)
402402
- Bug in :meth:`DataFrameGroupby.transform` produces incorrect result with transformation functions (:issue:`30918`)
403+
- Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` where aggregating a Boolean series (e.g. with ``sum``) produced an inconsistent result dtype (:issue:`32894`)
404+
403405

404406
Reshaping
405407
^^^^^^^^^

pandas/core/arrays/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
ExtensionArray,
33
ExtensionOpsMixin,
44
ExtensionScalarOpsMixin,
5-
try_cast_to_ea,
65
)
76
from pandas.core.arrays.boolean import BooleanArray
87
from pandas.core.arrays.categorical import Categorical
@@ -19,7 +18,6 @@
1918
"ExtensionArray",
2019
"ExtensionOpsMixin",
2120
"ExtensionScalarOpsMixin",
22-
"try_cast_to_ea",
2321
"BooleanArray",
2422
"Categorical",
2523
"DatetimeArray",

pandas/core/arrays/base.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas.util._decorators import Appender, Substitution
2020
from pandas.util._validators import validate_fillna_kwargs
2121

22+
from pandas.core.dtypes.cast import maybe_cast_to_extension_array
2223
from pandas.core.dtypes.common import is_array_like, is_list_like
2324
from pandas.core.dtypes.dtypes import ExtensionDtype
2425
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
@@ -32,29 +33,6 @@
3233
_extension_array_shared_docs: Dict[str, str] = dict()
3334

3435

35-
def try_cast_to_ea(cls_or_instance, obj, dtype=None):
36-
"""
37-
Call to `_from_sequence` that returns the object unchanged on Exception.
38-
39-
Parameters
40-
----------
41-
cls_or_instance : ExtensionArray subclass or instance
42-
obj : arraylike
43-
Values to pass to cls._from_sequence
44-
dtype : ExtensionDtype, optional
45-
46-
Returns
47-
-------
48-
ExtensionArray or obj
49-
"""
50-
try:
51-
result = cls_or_instance._from_sequence(obj, dtype=dtype)
52-
except Exception:
53-
# We can't predict what downstream EA constructors may raise
54-
result = obj
55-
return result
56-
57-
5836
class ExtensionArray:
5937
"""
6038
Abstract base class for custom 1-D array types.
@@ -1214,7 +1192,7 @@ def _maybe_convert(arr):
12141192
# https://github.com/pandas-dev/pandas/issues/22850
12151193
# We catch all regular exceptions here, and fall back
12161194
# to an ndarray.
1217-
res = try_cast_to_ea(self, arr)
1195+
res = maybe_cast_to_extension_array(type(self), arr)
12181196
if not isinstance(res, type(self)):
12191197
# exception raised in _from_sequence; ensure we have ndarray
12201198
res = np.asarray(arr)

pandas/core/arrays/categorical.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
)
2020
from pandas.util._validators import validate_bool_kwarg, validate_fillna_kwargs
2121

22-
from pandas.core.dtypes.cast import coerce_indexer_dtype, maybe_infer_to_datetimelike
22+
from pandas.core.dtypes.cast import (
23+
coerce_indexer_dtype,
24+
maybe_cast_to_extension_array,
25+
maybe_infer_to_datetimelike,
26+
)
2327
from pandas.core.dtypes.common import (
2428
ensure_int64,
2529
ensure_object,
@@ -47,11 +51,7 @@
4751
from pandas.core.accessor import PandasDelegate, delegate_names
4852
import pandas.core.algorithms as algorithms
4953
from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d
50-
from pandas.core.arrays.base import (
51-
ExtensionArray,
52-
_extension_array_shared_docs,
53-
try_cast_to_ea,
54-
)
54+
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
5555
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
5656
import pandas.core.common as com
5757
from pandas.core.construction import array, extract_array, sanitize_array
@@ -2568,7 +2568,7 @@ def _get_codes_for_values(values, categories):
25682568
# scalar objects. e.g.
25692569
# Categorical(array[Period, Period], categories=PeriodIndex(...))
25702570
cls = categories.dtype.construct_array_type()
2571-
values = try_cast_to_ea(cls, values)
2571+
values = maybe_cast_to_extension_array(cls, values)
25722572
if not isinstance(values, cls):
25732573
# exception raised in _from_sequence
25742574
values = ensure_object(values)

pandas/core/dtypes/cast.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
iNaT,
1717
)
1818
from pandas._libs.tslibs.timezones import tz_compare
19-
from pandas._typing import Dtype
19+
from pandas._typing import Dtype, DtypeObj
2020
from pandas.util._validators import validate_bool_kwarg
2121

2222
from pandas.core.dtypes.common import (
@@ -246,6 +246,97 @@ def trans(x):
246246
return result
247247

248248

249+
def maybe_cast_result(
    result, obj: ABCSeries, numeric_only: bool = False, how: str = ""
):
    """
    Cast ``result`` to the dtype expected for ``obj`` and ``how``, if possible.

    The target dtype is the dtype of ``obj`` after adjustment by
    ``maybe_cast_result_dtype`` for the operation named by ``how``.

    Parameters
    ----------
    result : array-like
        Result to cast. Scalars are returned untouched.
    obj : ABCSeries
        Input object from which ``result`` was calculated. When
        ``obj.ndim > 1`` the dtype is read from ``obj._values``.
    numeric_only : bool, default False
        If True, only downcast when the target dtype is numeric; if False,
        always attempt the downcast (datetimes included).
    how : str, default ""
        Name of the operation that produced ``result``.

    Returns
    -------
    array-like
        ``result``, possibly cast to the target dtype.
    """
    source_dtype = obj._values.dtype if obj.ndim > 1 else obj.dtype
    target_dtype = maybe_cast_result_dtype(source_dtype, how)

    if is_scalar(result):
        return result

    if is_extension_array_dtype(target_dtype) and target_dtype.kind != "M":
        # The computation may have returned values of any type; only rebuild
        # the extension array when the first element is already an instance
        # of the dtype's scalar type.
        if len(result) and isinstance(result[0], target_dtype.type):
            ea_cls = target_dtype.construct_array_type()
            return maybe_cast_to_extension_array(ea_cls, result, dtype=target_dtype)
        return result

    if not numeric_only or is_numeric_dtype(target_dtype):
        return maybe_downcast_to_dtype(result, target_dtype)

    return result
289+
290+
291+
def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
    """
    Get the desired dtype of a result based on the
    input dtype and how it was computed.

    Parameters
    ----------
    dtype : DtypeObj
        Input dtype.
    how : str
        How the result was computed.

    Returns
    -------
    DtypeObj
        The desired dtype of the result: ``int64`` for a NumPy bool input
        combined with ``"add"``, ``"cumsum"`` or ``"sum"``; otherwise
        ``dtype`` itself.
    """
    # ``np.bool_`` is the NumPy boolean scalar type; the bare ``np.bool``
    # alias was deprecated in NumPy 1.20 and removed in 1.24, so it must
    # not be used here.
    bool_dtype = np.dtype(np.bool_)
    int64_dtype = np.dtype(np.int64)

    # Summing booleans counts the True values, so the result is an integer.
    overrides = {
        (bool_dtype, "add"): int64_dtype,
        (bool_dtype, "cumsum"): int64_dtype,
        (bool_dtype, "sum"): int64_dtype,
    }
    return overrides.get((dtype, how), dtype)
314+
315+
316+
def maybe_cast_to_extension_array(cls, obj, dtype=None):
    """
    Build ``cls`` from ``obj`` via ``_from_sequence``, falling back to ``obj``.

    Parameters
    ----------
    cls : ExtensionArray subclass
        Must be a class, not an instance; this is asserted.
    obj : arraylike
        Values to pass to ``cls._from_sequence``.
    dtype : ExtensionDtype, optional
        Forwarded to ``cls._from_sequence``.

    Returns
    -------
    ExtensionArray or obj
        The constructed array, or ``obj`` unchanged if construction raised.
    """
    assert isinstance(cls, type), f"must pass a type: {cls}"
    try:
        return cls._from_sequence(obj, dtype=dtype)
    except Exception:
        # Downstream EA constructors may raise anything, so every regular
        # exception is treated as "cannot cast" and the input is handed back.
        return obj
338+
339+
249340
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
250341
"""
251342
A safe version of putmask that potentially upcasts the result.

pandas/core/groupby/generic.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from pandas.util._decorators import Appender, Substitution
3535

3636
from pandas.core.dtypes.cast import (
37+
maybe_cast_result,
38+
maybe_cast_result_dtype,
3739
maybe_convert_objects,
3840
maybe_downcast_numeric,
3941
maybe_downcast_to_dtype,
@@ -526,7 +528,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
526528
cast = self._transform_should_cast(func_nm)
527529
out = algorithms.take_1d(result._values, ids)
528530
if cast:
529-
out = self._try_cast(out, self.obj)
531+
out = maybe_cast_result(out, self.obj, how=func_nm)
530532
return Series(out, index=self.obj.index, name=self.obj.name)
531533

532534
def filter(self, func, dropna=True, *args, **kwargs):
@@ -1072,8 +1074,10 @@ def _cython_agg_blocks(
10721074
assert not isinstance(result, DataFrame)
10731075

10741076
if result is not no_result:
1075-
# see if we can cast the block back to the original dtype
1076-
result = maybe_downcast_numeric(result, block.dtype)
1077+
# see if we can cast the block to the desired dtype
1078+
# this may not be the original dtype
1079+
dtype = maybe_cast_result_dtype(block.dtype, how)
1080+
result = maybe_downcast_numeric(result, dtype)
10771081

10781082
if block.is_extension and isinstance(result, np.ndarray):
10791083
# e.g. block.values was an IntegerArray
@@ -1175,7 +1179,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
11751179

11761180
else:
11771181
if cast:
1178-
result[item] = self._try_cast(result[item], data)
1182+
result[item] = maybe_cast_result(result[item], data)
11791183

11801184
result_columns = obj.columns
11811185
if cannot_agg:
@@ -1460,7 +1464,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14601464
# TODO: we have no test cases that get here with EA dtypes;
14611465
# try_cast may not be needed if EAs never get here
14621466
if cast:
1463-
res = self._try_cast(res, obj.iloc[:, i])
1467+
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
14641468
output.append(res)
14651469

14661470
return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)

pandas/core/groupby/groupby.py

+7-38
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,10 @@ class providing the base-class of operations.
3939
from pandas.errors import AbstractMethodError
4040
from pandas.util._decorators import Appender, Substitution, cache_readonly
4141

42-
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
42+
from pandas.core.dtypes.cast import maybe_cast_result
4343
from pandas.core.dtypes.common import (
4444
ensure_float,
4545
is_datetime64_dtype,
46-
is_extension_array_dtype,
4746
is_integer_dtype,
4847
is_numeric_dtype,
4948
is_object_dtype,
@@ -53,7 +52,7 @@ class providing the base-class of operations.
5352

5453
from pandas.core import nanops
5554
import pandas.core.algorithms as algorithms
56-
from pandas.core.arrays import Categorical, DatetimeArray, try_cast_to_ea
55+
from pandas.core.arrays import Categorical, DatetimeArray
5756
from pandas.core.base import DataError, PandasObject, SelectionMixin
5857
import pandas.core.common as com
5958
from pandas.core.frame import DataFrame
@@ -792,36 +791,6 @@ def _cumcount_array(self, ascending: bool = True):
792791
rev[sorter] = np.arange(count, dtype=np.intp)
793792
return out[rev].astype(np.int64, copy=False)
794793

795-
def _try_cast(self, result, obj, numeric_only: bool = False):
796-
"""
797-
Try to cast the result to our obj original type,
798-
we may have roundtripped through object in the mean-time.
799-
800-
If numeric_only is True, then only try to cast numerics
801-
and not datetimelikes.
802-
803-
"""
804-
if obj.ndim > 1:
805-
dtype = obj._values.dtype
806-
else:
807-
dtype = obj.dtype
808-
809-
if not is_scalar(result):
810-
if is_extension_array_dtype(dtype) and dtype.kind != "M":
811-
# The function can return something of any type, so check
812-
# if the type is compatible with the calling EA.
813-
# datetime64tz is handled correctly in agg_series,
814-
# so is excluded here.
815-
816-
if len(result) and isinstance(result[0], dtype.type):
817-
cls = dtype.construct_array_type()
818-
result = try_cast_to_ea(cls, result, dtype=dtype)
819-
820-
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
821-
result = maybe_downcast_to_dtype(result, dtype)
822-
823-
return result
824-
825794
def _transform_should_cast(self, func_nm: str) -> bool:
826795
"""
827796
Parameters
@@ -852,7 +821,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
852821
continue
853822

854823
if self._transform_should_cast(how):
855-
result = self._try_cast(result, obj)
824+
result = maybe_cast_result(result, obj, how=how)
856825

857826
key = base.OutputKey(label=name, position=idx)
858827
output[key] = result
@@ -895,12 +864,12 @@ def _cython_agg_general(
895864
assert len(agg_names) == result.shape[1]
896865
for result_column, result_name in zip(result.T, agg_names):
897866
key = base.OutputKey(label=result_name, position=idx)
898-
output[key] = self._try_cast(result_column, obj)
867+
output[key] = maybe_cast_result(result_column, obj, how=how)
899868
idx += 1
900869
else:
901870
assert result.ndim == 1
902871
key = base.OutputKey(label=name, position=idx)
903-
output[key] = self._try_cast(result, obj)
872+
output[key] = maybe_cast_result(result, obj, how=how)
904873
idx += 1
905874

906875
if len(output) == 0:
@@ -929,7 +898,7 @@ def _python_agg_general(self, func, *args, **kwargs):
929898

930899
assert result is not None
931900
key = base.OutputKey(label=name, position=idx)
932-
output[key] = self._try_cast(result, obj, numeric_only=True)
901+
output[key] = maybe_cast_result(result, obj, numeric_only=True)
933902

934903
if len(output) == 0:
935904
return self._python_apply_general(f)
@@ -944,7 +913,7 @@ def _python_agg_general(self, func, *args, **kwargs):
944913
if is_numeric_dtype(values.dtype):
945914
values = ensure_float(values)
946915

947-
output[key] = self._try_cast(values[mask], result)
916+
output[key] = maybe_cast_result(values[mask], result)
948917

949918
return self._wrap_aggregated_output(output)
950919

pandas/core/series.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from pandas.util._decorators import Appender, Substitution, doc
2828
from pandas.util._validators import validate_bool_kwarg, validate_percentile
2929

30-
from pandas.core.dtypes.cast import convert_dtypes, validate_numeric_casting
30+
from pandas.core.dtypes.cast import (
31+
convert_dtypes,
32+
maybe_cast_to_extension_array,
33+
validate_numeric_casting,
34+
)
3135
from pandas.core.dtypes.common import (
3236
_is_unorderable_exception,
3337
ensure_platform_int,
@@ -59,7 +63,7 @@
5963
import pandas as pd
6064
from pandas.core import algorithms, base, generic, nanops, ops
6165
from pandas.core.accessor import CachedAccessor
62-
from pandas.core.arrays import ExtensionArray, try_cast_to_ea
66+
from pandas.core.arrays import ExtensionArray
6367
from pandas.core.arrays.categorical import CategoricalAccessor
6468
from pandas.core.arrays.sparse import SparseAccessor
6569
import pandas.core.common as com
@@ -2721,7 +2725,7 @@ def combine(self, other, func, fill_value=None) -> "Series":
27212725
# TODO: can we do this for only SparseDtype?
27222726
# The function can return something of any type, so check
27232727
# if the type is compatible with the calling EA.
2724-
new_values = try_cast_to_ea(self._values, new_values)
2728+
new_values = maybe_cast_to_extension_array(type(self._values), new_values)
27252729
return self._constructor(new_values, index=new_index, name=new_name)
27262730

27272731
def combine_first(self, other) -> "Series":

0 commit comments

Comments
 (0)