Skip to content

PERF: efficient argmax/argmin for SparseArray #47779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 27, 2022
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ Performance improvements
- Performance improvement in datetime arrays string formatting when one of the default strftime formats ``"%Y-%m-%d %H:%M:%S"`` or ``"%Y-%m-%d %H:%M:%S.%f"`` is used. (:issue:`44764`)
- Performance improvement in :meth:`Series.to_sql` and :meth:`DataFrame.to_sql` (:class:`SQLiteTable`) when processing time arrays. (:issue:`44764`)
- Performance improvements to :func:`read_sas` (:issue:`47403`, :issue:`47404`, :issue:`47405`)
- Performance improvement in ``argmax`` and ``argmin`` for :class:`arrays.SparseArray` (:issue:`34197`)
-

.. ---------------------------------------------------------------------------
Expand Down
44 changes: 43 additions & 1 deletion pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
from pandas.compat.numpy import function as nv
from pandas.errors import PerformanceWarning
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import validate_insert_loc
from pandas.util._validators import (
validate_bool_kwarg,
validate_insert_loc,
)

from pandas.core.dtypes.astype import astype_nansafe
from pandas.core.dtypes.cast import (
Expand Down Expand Up @@ -1636,6 +1639,45 @@ def _min_max(self, kind: Literal["min", "max"], skipna: bool) -> Scalar:
else:
return na_value_for_dtype(self.dtype.subtype, compat=False)

def _argmin_argmax(self, kind: Literal["argmin", "argmax"]) -> int:

values = self._sparse_values
index = self._sparse_index.indices
mask = np.asarray(isna(values))
func = np.argmax if kind == "argmax" else np.argmin

idx = np.arange(values.shape[0])
non_nans = values[~mask]
non_nan_idx = idx[~mask]

_candidate = non_nan_idx[func(non_nans)]
candidate = index[_candidate]

if isna(self.fill_value):
return candidate
if kind == "argmin" and self[candidate] < self.fill_value:
return candidate
if kind == "argmax" and self[candidate] > self.fill_value:
return candidate
_loc = self._first_fill_value_loc()
if _loc == -1:
# fill_value doesn't exist
return candidate
else:
return _loc

def argmax(self, skipna: bool = True) -> int:
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
raise NotImplementedError
return self._argmin_argmax("argmax")

def argmin(self, skipna: bool = True) -> int:
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
raise NotImplementedError
return self._argmin_argmax("argmin")

# ------------------------------------------------------------------------
# Ufuncs
# ------------------------------------------------------------------------
Expand Down
38 changes: 38 additions & 0 deletions pandas/tests/arrays/sparse/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,41 @@ def test_na_value_if_no_valid_values(self, func, data, dtype, expected):
assert result is NaT or np.isnat(result)
else:
assert np.isnan(result)


class TestArgmaxArgmin:
@pytest.mark.parametrize(
"arr,argmax_expected,argmin_expected",
[
(SparseArray([1, 2, 0, 1, 2]), 1, 2),
(SparseArray([-1, -2, 0, -1, -2]), 2, 1),
(SparseArray([np.nan, 1, 0, 0, np.nan, -1]), 1, 5),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2]), 5, 2),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=-1), 5, 2),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=0), 5, 2),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=1), 5, 2),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=2), 5, 2),
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=3), 5, 2),
(SparseArray([0] * 10 + [-1], fill_value=0), 0, 10),
(SparseArray([0] * 10 + [-1], fill_value=-1), 0, 10),
(SparseArray([0] * 10 + [-1], fill_value=1), 0, 10),
(SparseArray([-1] + [0] * 10, fill_value=0), 1, 0),
(SparseArray([1] + [0] * 10, fill_value=0), 0, 1),
(SparseArray([-1] + [0] * 10, fill_value=-1), 1, 0),
(SparseArray([1] + [0] * 10, fill_value=1), 0, 1),
],
)
def test_argmax_argmin(self, arr, argmax_expected, argmin_expected):
argmax_result = arr.argmax()
argmin_result = arr.argmin()
assert argmax_result == argmax_expected
assert argmin_result == argmin_expected

@pytest.mark.parametrize(
"arr,method",
[(SparseArray([]), "argmax"), (SparseArray([]), "argmin")],
)
def test_empty_array(self, arr, method):
msg = f"attempt to get {method} of an empty sequence"
with pytest.raises(ValueError, match=msg):
arr.argmax() if method == "argmax" else arr.argmin()