Skip to content

Commit 05ae290

Browse files
committed
Make argmin/max work lazy with dask (pydata#3237).
1 parent 76d4a67 commit 05ae290

File tree

2 files changed

+6
-21
lines changed

2 files changed

+6
-21
lines changed

xarray/core/nanops.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,35 +91,18 @@ def nanargmin(a, axis=None):
9191
fill_value = dtypes.get_pos_infinity(a.dtype)
9292
if a.dtype.kind == "O":
9393
return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
94-
a, mask = _replace_nan(a, fill_value)
95-
if isinstance(a, dask_array_type):
96-
res = dask_array.argmin(a, axis=axis)
97-
else:
98-
res = np.argmin(a, axis=axis)
9994

100-
if mask is not None:
101-
mask = mask.all(axis=axis)
102-
if mask.any():
103-
raise ValueError("All-NaN slice encountered")
104-
return res
95+
module = dask_array if isinstance(a, dask_array_type) else nputils
96+
return module.nanargmin(a, axis=axis)
10597

10698

10799
def nanargmax(a, axis=None):
108100
fill_value = dtypes.get_neg_infinity(a.dtype)
109101
if a.dtype.kind == "O":
110102
return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
111103

112-
a, mask = _replace_nan(a, fill_value)
113-
if isinstance(a, dask_array_type):
114-
res = dask_array.argmax(a, axis=axis)
115-
else:
116-
res = np.argmax(a, axis=axis)
117-
118-
if mask is not None:
119-
mask = mask.all(axis=axis)
120-
if mask.any():
121-
raise ValueError("All-NaN slice encountered")
122-
return res
104+
module = dask_array if isinstance(a, dask_array_type) else nputils
105+
return module.nanargmax(a, axis=axis)
123106

124107

125108
def nansum(a, axis=None, dtype=None, out=None, min_count=None):

xarray/core/nputils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,5 @@ def f(values, axis=None, **kwargs):
236236
nanprod = _create_bottleneck_method("nanprod")
237237
nancumsum = _create_bottleneck_method("nancumsum")
238238
nancumprod = _create_bottleneck_method("nancumprod")
239+
nanargmin = _create_bottleneck_method("nanargmin")
240+
nanargmax = _create_bottleneck_method("nanargmax")

0 commit comments

Comments
 (0)