diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index d9fbc3ed122fa..9a72dee8d87ca 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -556,12 +556,12 @@ def infer_fill_value(val): return np.nan -def maybe_fill(arr, fill_value=np.nan): +def maybe_fill(arr: np.ndarray) -> np.ndarray: """ - if we have a compatible fill_value and arr dtype, then fill + Fill numpy.ndarray with NaN, unless we have an integer or boolean dtype. """ - if isna_compat(arr, fill_value): - arr.fill(fill_value) + if arr.dtype.kind not in ("u", "i", "b"): + arr.fill(np.nan) return arr diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index e505359987eb3..74e96015b4544 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -31,7 +31,7 @@ import pandas._libs.groupby as libgroupby import pandas._libs.reduction as libreduction from pandas._typing import ( - ArrayLike, + DtypeObj, F, FrameOrSeries, Shape, @@ -46,7 +46,6 @@ maybe_downcast_to_dtype, ) from pandas.core.dtypes.common import ( - ensure_float, ensure_float64, ensure_int64, ensure_int_or_float, @@ -491,7 +490,9 @@ def _get_cython_func_and_vals( return func, values @final - def _disallow_invalid_ops(self, values: ArrayLike, how: str): + def _disallow_invalid_ops( + self, dtype: DtypeObj, how: str, is_numeric: bool = False + ): """ Check if we can do this operation with our cython functions. @@ -501,7 +502,9 @@ def _disallow_invalid_ops(self, values: ArrayLike, how: str): This is either not a valid function for this dtype, or valid but not implemented in cython. 
""" - dtype = values.dtype + if is_numeric: + # never an invalid op for those dtypes, so return early as fastpath + return if is_categorical_dtype(dtype) or is_sparse(dtype): # categoricals are only 1d, so we @@ -589,32 +592,34 @@ def _cython_operation( # as we can have 1D ExtensionArrays that we need to treat as 2D assert axis == 1, axis + dtype = values.dtype + is_numeric = is_numeric_dtype(dtype) + # can we do this operation with our cython functions # if not raise NotImplementedError - self._disallow_invalid_ops(values, how) + self._disallow_invalid_ops(dtype, how, is_numeric) - if is_extension_array_dtype(values.dtype): + if is_extension_array_dtype(dtype): return self._ea_wrap_cython_operation( kind, values, how, axis, min_count, **kwargs ) - is_datetimelike = needs_i8_conversion(values.dtype) - is_numeric = is_numeric_dtype(values.dtype) + is_datetimelike = needs_i8_conversion(dtype) if is_datetimelike: values = values.view("int64") is_numeric = True - elif is_bool_dtype(values.dtype): + elif is_bool_dtype(dtype): values = ensure_int_or_float(values) - elif is_integer_dtype(values): + elif is_integer_dtype(dtype): # we use iNaT for the missing value on ints # so pre-convert to guard this condition if (values == iNaT).any(): values = ensure_float64(values) else: values = ensure_int_or_float(values) - elif is_numeric and not is_complex_dtype(values): - values = ensure_float64(ensure_float(values)) + elif is_numeric and not is_complex_dtype(dtype): + values = ensure_float64(values) else: values = values.astype(object) @@ -649,20 +654,18 @@ def _cython_operation( codes, _, _ = self.group_info if kind == "aggregate": - result = maybe_fill(np.empty(out_shape, dtype=out_dtype), fill_value=np.nan) + result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) counts = np.zeros(self.ngroups, dtype=np.int64) result = self._aggregate(result, counts, values, codes, func, min_count) elif kind == "transform": - result = maybe_fill( - np.empty(values.shape, dtype=out_dtype), 
fill_value=np.nan - ) + result = maybe_fill(np.empty(values.shape, dtype=out_dtype)) # TODO: min_count result = self._transform( result, values, codes, func, is_datetimelike, **kwargs ) - if is_integer_dtype(result) and not is_datetimelike: + if is_integer_dtype(result.dtype) and not is_datetimelike: mask = result == iNaT if mask.any(): result = result.astype("float64") @@ -682,9 +685,9 @@ def _cython_operation( # e.g. if we are int64 and need to restore to datetime64/timedelta64 # "rank" is the only member of cython_cast_blocklist we get here dtype = maybe_cast_result_dtype(orig_values.dtype, how) - # error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type - # "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]" - result = maybe_downcast_to_dtype(result, dtype) # type: ignore[arg-type] + # error: Incompatible types in assignment (expression has type + # "Union[ExtensionArray, ndarray]", variable has type "ndarray") + result = maybe_downcast_to_dtype(result, dtype) # type: ignore[assignment] return result