Skip to content

Commit 8d7a379

Browse files
GYHHAHAmroeschke
andauthored
PERF: efficient argmax/argmin for SparseArray (#47779)
* Update test_reductions.py * Update v1.5.0.rst * Update array.py * Update base.py * Update sorting.py * fix format * Update array.py * Update base.py * Update sorting.py * Update base.py * Update array.py * Update test_reductions.py * fix format * fix import * Update test_reductions.py * Update array.py * move to perf * Update doc/source/whatsnew/v1.5.0.rst Co-authored-by: Matthew Roeschke <[email protected]>
1 parent e12b318 commit 8d7a379

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,7 @@ Performance improvements
802802
- Performance improvement in datetime arrays string formatting when one of the default strftime formats ``"%Y-%m-%d %H:%M:%S"`` or ``"%Y-%m-%d %H:%M:%S.%f"`` is used. (:issue:`44764`)
803803
- Performance improvement in :meth:`Series.to_sql` and :meth:`DataFrame.to_sql` (:class:`SQLiteTable`) when processing time arrays. (:issue:`44764`)
804804
- Performance improvements to :func:`read_sas` (:issue:`47403`, :issue:`47404`, :issue:`47405`)
805+
- Performance improvement in ``argmax`` and ``argmin`` for :class:`arrays.SparseArray` (:issue:`34197`)
805806
-
806807

807808
.. ---------------------------------------------------------------------------

pandas/core/arrays/sparse/array.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242
from pandas.compat.numpy import function as nv
4343
from pandas.errors import PerformanceWarning
4444
from pandas.util._exceptions import find_stack_level
45-
from pandas.util._validators import validate_insert_loc
45+
from pandas.util._validators import (
46+
validate_bool_kwarg,
47+
validate_insert_loc,
48+
)
4649

4750
from pandas.core.dtypes.astype import astype_nansafe
4851
from pandas.core.dtypes.cast import (
@@ -1646,6 +1649,45 @@ def _min_max(self, kind: Literal["min", "max"], skipna: bool) -> Scalar:
16461649
else:
16471650
return na_value_for_dtype(self.dtype.subtype, compat=False)
16481651

1652+
def _argmin_argmax(self, kind: Literal["argmin", "argmax"]) -> int:
1653+
1654+
values = self._sparse_values
1655+
index = self._sparse_index.indices
1656+
mask = np.asarray(isna(values))
1657+
func = np.argmax if kind == "argmax" else np.argmin
1658+
1659+
idx = np.arange(values.shape[0])
1660+
non_nans = values[~mask]
1661+
non_nan_idx = idx[~mask]
1662+
1663+
_candidate = non_nan_idx[func(non_nans)]
1664+
candidate = index[_candidate]
1665+
1666+
if isna(self.fill_value):
1667+
return candidate
1668+
if kind == "argmin" and self[candidate] < self.fill_value:
1669+
return candidate
1670+
if kind == "argmax" and self[candidate] > self.fill_value:
1671+
return candidate
1672+
_loc = self._first_fill_value_loc()
1673+
if _loc == -1:
1674+
# fill_value doesn't exist
1675+
return candidate
1676+
else:
1677+
return _loc
1678+
1679+
def argmax(self, skipna: bool = True) -> int:
1680+
validate_bool_kwarg(skipna, "skipna")
1681+
if not skipna and self._hasna:
1682+
raise NotImplementedError
1683+
return self._argmin_argmax("argmax")
1684+
1685+
def argmin(self, skipna: bool = True) -> int:
1686+
validate_bool_kwarg(skipna, "skipna")
1687+
if not skipna and self._hasna:
1688+
raise NotImplementedError
1689+
return self._argmin_argmax("argmin")
1690+
16491691
# ------------------------------------------------------------------------
16501692
# Ufuncs
16511693
# ------------------------------------------------------------------------

pandas/tests/arrays/sparse/test_reductions.py

+38
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,41 @@ def test_na_value_if_no_valid_values(self, func, data, dtype, expected):
268268
assert result is NaT or np.isnat(result)
269269
else:
270270
assert np.isnan(result)
271+
272+
273+
class TestArgmaxArgmin:
274+
@pytest.mark.parametrize(
275+
"arr,argmax_expected,argmin_expected",
276+
[
277+
(SparseArray([1, 2, 0, 1, 2]), 1, 2),
278+
(SparseArray([-1, -2, 0, -1, -2]), 2, 1),
279+
(SparseArray([np.nan, 1, 0, 0, np.nan, -1]), 1, 5),
280+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2]), 5, 2),
281+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=-1), 5, 2),
282+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=0), 5, 2),
283+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=1), 5, 2),
284+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=2), 5, 2),
285+
(SparseArray([np.nan, 1, 0, 0, np.nan, 2], fill_value=3), 5, 2),
286+
(SparseArray([0] * 10 + [-1], fill_value=0), 0, 10),
287+
(SparseArray([0] * 10 + [-1], fill_value=-1), 0, 10),
288+
(SparseArray([0] * 10 + [-1], fill_value=1), 0, 10),
289+
(SparseArray([-1] + [0] * 10, fill_value=0), 1, 0),
290+
(SparseArray([1] + [0] * 10, fill_value=0), 0, 1),
291+
(SparseArray([-1] + [0] * 10, fill_value=-1), 1, 0),
292+
(SparseArray([1] + [0] * 10, fill_value=1), 0, 1),
293+
],
294+
)
295+
def test_argmax_argmin(self, arr, argmax_expected, argmin_expected):
296+
argmax_result = arr.argmax()
297+
argmin_result = arr.argmin()
298+
assert argmax_result == argmax_expected
299+
assert argmin_result == argmin_expected
300+
301+
@pytest.mark.parametrize(
302+
"arr,method",
303+
[(SparseArray([]), "argmax"), (SparseArray([]), "argmin")],
304+
)
305+
def test_empty_array(self, arr, method):
306+
msg = f"attempt to get {method} of an empty sequence"
307+
with pytest.raises(ValueError, match=msg):
308+
arr.argmax() if method == "argmax" else arr.argmin()

0 commit comments

Comments
 (0)