diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst
index 8d3ac0e396430..38a829200e311 100644
--- a/doc/source/whatsnew/v3.0.0.rst
+++ b/doc/source/whatsnew/v3.0.0.rst
@@ -91,6 +91,7 @@ Other enhancements
 - Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
 - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
 - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
+- :class:`ExtensionArray` has gained a default, non-performant ``round`` method (:issue:`49387`)
 
 .. ---------------------------------------------------------------------------
 .. _whatsnew_300.notable_bug_fixes:
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 0b90bcea35100..8e24a566a9cc7 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -1259,7 +1259,10 @@ def round(self, decimals: int = 0, *args, **kwargs) -> Self:
         DataFrame.round : Round values of a DataFrame.
         Series.round : Round values of a Series.
         """
-        return type(self)(pc.round(self._pa_array, ndigits=decimals))
+        if not self.dtype._is_numeric or self.dtype._is_boolean:
+            raise TypeError("Cannot round non-numeric type.")
+        result = pc.round(self._pa_array, ndigits=decimals)
+        return type(self)(result.cast(self._pa_array.type))
 
     @doc(ExtensionArray.searchsorted)
     def searchsorted(
diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index d0048e122051a..2d65db2819b80 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -2515,6 +2515,44 @@ def _mode(self, dropna: bool = True) -> Self:
         result, _ = mode(self, dropna=dropna)
         return result  # type: ignore[return-value]
 
+    def round(self, decimals: int = 0, *args, **kwargs) -> Self:
+        """
+        Round each value in the array to the given number of decimals.
+
+        Parameters
+        ----------
+        decimals : int, default 0
+            Number of decimal places passed to the builtin ``round`` for
+            each non-missing element.
+        *args, **kwargs
+            Accepted but unused by this default implementation.
+
+        Returns
+        -------
+        ExtensionArray
+            Rounded array with the same dtype; missing values are kept as-is.
+            Boolean arrays are returned as an unmodified copy.
+
+        Raises
+        ------
+        TypeError
+            If the dtype is neither boolean nor numeric.
+        """
+        # Implementer note: This is a non-optimized default implementation.
+        # Implementers are encouraged to override this method to avoid
+        # elementwise rounding.
+        if self.dtype._is_boolean:
+            return self.copy()
+        if not self.dtype._is_numeric:
+            raise TypeError(f"Cannot round {type(self)} dtype as it is non-numeric")
+        return self._from_sequence(
+            [
+                round(element, decimals) if not element_isna else element
+                for (element, element_isna) in zip(self, self.isna())
+            ],
+            dtype=self.dtype,
+        )
+
     def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
         if any(
             isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame)) for other in inputs
diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py
index 994d7b1d0081c..5c93cc2125454 100644
--- a/pandas/core/arrays/datetimelike.py
+++ b/pandas/core/arrays/datetimelike.py
@@ -2209,7 +2209,7 @@ def _round(self, freq, mode, ambiguous, nonexistent):
         return self._simple_new(result, dtype=self.dtype)
 
     @Appender((_round_doc + _round_example).format(op="round"))
-    def round(
+    def round(  # type: ignore[override]
         self,
         freq,
         ambiguous: TimeAmbiguous = "raise",
diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py
index 137dbb6e4d139..d567403429e11 100644
--- a/pandas/core/arrays/sparse/array.py
+++ b/pandas/core/arrays/sparse/array.py
@@ -1679,6 +1679,24 @@ def argmin(self, skipna: bool = True) -> int:
             raise ValueError("Encountered an NA value with skipna=False")
         return self._argmin_argmax("argmin")
 
+    def round(self, decimals: int = 0, *args, **kwargs) -> Self:
+        """
+        Round the stored sparse values to the given number of decimals.
+
+        Only ``sp_values`` are rounded; the sparse index and dtype are reused.
+        NOTE(review): the fill value is not touched here - confirm intended.
+        """
+        # Keep the stored dtype explicit so the result does not depend on
+        # NumPy's inference (e.g. an empty list would otherwise be float64).
+        new_values = np.array(
+            [
+                round(element, decimals) if not isna(element) else element
+                for element in self.sp_values
+            ],
+            dtype=self.sp_values.dtype,
+        )
+        return self._simple_new(new_values, self._sparse_index, self.dtype)
+
     # ------------------------------------------------------------------------
     # Ufuncs
     # ------------------------------------------------------------------------
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 6aa5062b8ed86..d25eb3e01a377 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -1499,11 +1499,7 @@ def round(self, decimals: int) -> Self:
         """
         if not self.is_numeric or self.is_bool:
             return self.copy(deep=False)
-        # TODO: round only defined on BaseMaskedArray
-        # Series also does this, so would need to fix both places
-        # error: Item "ExtensionArray" of "Union[ndarray[Any, Any], ExtensionArray]"
-        # has no attribute "round"
-        values = self.values.round(decimals)  # type: ignore[union-attr]
+        values = self.values.round(decimals)
 
         refs = None
         if values is self.values:
diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py
index fd9fec0cb490c..edd2ed77b85c8 100644
--- a/pandas/tests/extension/base/methods.py
+++ b/pandas/tests/extension/base/methods.py
@@ -726,3 +726,20 @@ def test_equals(self, data, na_value, as_series, box):
     def test_equals_same_data_different_object(self, data):
         # https://github.com/pandas-dev/pandas/issues/34660
         assert pd.Series(data).equals(pd.Series(data))
+
+    def test_round(self, data):
+        if not data.dtype._is_numeric:
+            msg = "non-numeric"
+            with pytest.raises(TypeError, match=msg):
+                data.round()
+        elif data.dtype._is_boolean:
+            result = pd.Series(data).round()
+            expected = pd.Series(data)
+            tm.assert_series_equal(result, expected)
+        else:
+            result = pd.Series(data).round()
+            expected = pd.Series(
+                [round(element) if pd.notna(element) else element for element in data],
+                dtype=data.dtype,
+            )
+            tm.assert_series_equal(result, expected)
diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py
index 356d5352f41f4..35ec8730a3a64 100644
--- a/pandas/tests/extension/test_datetime.py
+++ b/pandas/tests/extension/test_datetime.py
@@ -143,6 +143,10 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
         else:
             return super().check_reduce(ser, op_name, skipna)
 
+    @pytest.mark.skip("DatetimeArray uses a different function signature for round")
+    def test_round(self):
+        pass
+
 
 class Test2DCompat(base.NDArrayBacked2DTests):
     pass