@@ -91,35 +91,18 @@ def nanargmin(a, axis=None):
91
91
fill_value = dtypes .get_pos_infinity (a .dtype )
92
92
if a .dtype .kind == "O" :
93
93
return _nan_argminmax_object ("argmin" , fill_value , a , axis = axis )
94
- a , mask = _replace_nan (a , fill_value )
95
- if isinstance (a , dask_array_type ):
96
- res = dask_array .argmin (a , axis = axis )
97
- else :
98
- res = np .argmin (a , axis = axis )
99
94
100
- if mask is not None :
101
- mask = mask .all (axis = axis )
102
- if mask .any ():
103
- raise ValueError ("All-NaN slice encountered" )
104
- return res
95
+ module = dask_array if isinstance (a , dask_array_type ) else nputils
96
+ return module .nanargmin (a , axis = axis )
105
97
106
98
107
99
def nanargmax (a , axis = None ):
108
100
fill_value = dtypes .get_neg_infinity (a .dtype )
109
101
if a .dtype .kind == "O" :
110
102
return _nan_argminmax_object ("argmax" , fill_value , a , axis = axis )
111
103
112
- a , mask = _replace_nan (a , fill_value )
113
- if isinstance (a , dask_array_type ):
114
- res = dask_array .argmax (a , axis = axis )
115
- else :
116
- res = np .argmax (a , axis = axis )
117
-
118
- if mask is not None :
119
- mask = mask .all (axis = axis )
120
- if mask .any ():
121
- raise ValueError ("All-NaN slice encountered" )
122
- return res
104
+ module = dask_array if isinstance (a , dask_array_type ) else nputils
105
+ return module .nanargmax (a , axis = axis )
123
106
124
107
125
108
def nansum (a , axis = None , dtype = None , out = None , min_count = None ):
0 commit comments