diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index f313b49cd198d..7a90d96926475 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -802,6 +802,7 @@ Performance improvements - Performance improvement in datetime arrays string formatting when one of the default strftime formats ``"%Y-%m-%d %H:%M:%S"`` or ``"%Y-%m-%d %H:%M:%S.%f"`` is used. (:issue:`44764`) - Performance improvement in :meth:`Series.to_sql` and :meth:`DataFrame.to_sql` (:class:`SQLiteTable`) when processing time arrays. (:issue:`44764`) - Performance improvements to :func:`read_sas` (:issue:`47403`, :issue:`47404`, :issue:`47405`) +- Performance improvement in ``argmax`` and ``argmin`` for :class:`arrays.SparseArray` (:issue:`34197`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 5653d87a4570b..bf65da0412642 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -42,7 +42,10 @@ from pandas.compat.numpy import function as nv from pandas.errors import PerformanceWarning from pandas.util._exceptions import find_stack_level -from pandas.util._validators import validate_insert_loc +from pandas.util._validators import ( + validate_bool_kwarg, + validate_insert_loc, +) from pandas.core.dtypes.astype import astype_nansafe from pandas.core.dtypes.cast import ( @@ -1636,6 +1639,45 @@ def _min_max(self, kind: Literal["min", "max"], skipna: bool) -> Scalar: else: return na_value_for_dtype(self.dtype.subtype, compat=False) + def _argmin_argmax(self, kind: Literal["argmin", "argmax"]) -> int: + + values = self._sparse_values + index = self._sparse_index.indices + mask = np.asarray(isna(values)) + func = np.argmax if kind == "argmax" else np.argmin + + idx = np.arange(values.shape[0]) + non_nans = values[~mask] + non_nan_idx = idx[~mask] + + _candidate = non_nan_idx[func(non_nans)] + candidate = index[_candidate] + + if isna(self.fill_value): + return candidate + if kind == "argmin" and self[candidate] < self.fill_value: + return candidate + if kind == "argmax" and self[candidate] > self.fill_value: + return candidate + _loc = self._first_fill_value_loc() + if _loc == -1: + # fill_value doesn't exist + return candidate + else: + return _loc + + def argmax(self, skipna: bool = True) -> int: + validate_bool_kwarg(skipna, "skipna") + if not skipna and self._hasna: + raise NotImplementedError + return self._argmin_argmax("argmax") + + def argmin(self, skipna: bool = True) -> int: + validate_bool_kwarg(skipna, "skipna") + if not skipna and self._hasna: + raise NotImplementedError + return self._argmin_argmax("argmin") + # ------------------------------------------------------------------------ # Ufuncs # ------------------------------------------------------------------------ diff --git a/pandas/tests/arrays/sparse/test_reductions.py b/pandas/tests/arrays/sparse/test_reductions.py index a33a282bb4869..2dd80c52f1419 100644 --- a/pandas/tests/arrays/sparse/test_reductions.py +++ b/pandas/tests/arrays/sparse/test_reductions.py @@ -268,3 +268,41 @@ def test_na_value_if_no_valid_values(self, func, data, dtype, expected): assert result is NaT or np.isnat(result) else: assert np.isnan(result) + + +class TestArgmaxArgmin: + @pytest.mark.parametrize( + "arr,argmax_expected,argmin_expected", + [ + (SparseArray([1, 2, 0, 1, 2]), 1, 2), + (SparseArray([-1, -2, 0, -1, -2]), 2, 1), + (SparseArray([np.nan, 1, 0, 0, np.nan, -1]), 1, 5), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2]), 5, 2), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=-1), 5, 2), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=0), 5, 2), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=1), 5, 2), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=2), 5, 2), + (SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=3), 5, 2), + (SparseArray([0] * 10 + [-1], fill_value=0), 0, 10), + (SparseArray([0] * 10 + [-1], fill_value=-1), 0, 10), + (SparseArray([0] * 10 + [-1], fill_value=1), 0, 10), + (SparseArray([-1] + [0] * 10, fill_value=0), 1, 0), + (SparseArray([1] + [0] * 10, fill_value=0), 0, 1), + (SparseArray([-1] + [0] * 10, fill_value=-1), 1, 0), + (SparseArray([1] + [0] * 10, fill_value=1), 0, 1), + ], + ) + def test_argmax_argmin(self, arr, argmax_expected, argmin_expected): + argmax_result = arr.argmax() + argmin_result = arr.argmin() + assert argmax_result == argmax_expected + assert argmin_result == argmin_expected + + @pytest.mark.parametrize( + "arr,method", + [(SparseArray([]), "argmax"), (SparseArray([]), "argmin")], + ) + def test_empty_array(self, arr, method): + msg = f"attempt to get {method} of an empty sequence" + with pytest.raises(ValueError, match=msg): + arr.argmax() if method == "argmax" else arr.argmin()