@@ -197,6 +197,79 @@ def _ignore_warnings_if(condition):
197
197
yield
198
198
199
199
200
+ def _nansum_object (value , axis = None , ** kwargs ):
201
+ """ In house nansum for object array """
202
+ value = fillna (value , 0 )
203
+ return _dask_or_eager_func ('sum' )(value , axis = axis , ** kwargs )
204
+
205
+
206
+ def _nan_minmax_object (func , get_fill_value , value , axis = None , ** kwargs ):
207
+ """ In house nanmin and nanmax for object array """
208
+ fill_value = get_fill_value (value .dtype )
209
+ valid_count = count (value , axis = axis )
210
+ filled_value = fillna (value , fill_value )
211
+ data = _dask_or_eager_func (func )(filled_value , axis = axis , ** kwargs )
212
+ if not hasattr (data , 'dtype' ): # scalar case
213
+ data = dtypes .fill_value (value .dtype ) if valid_count == 0 else data
214
+ return np .array (data , dtype = value .dtype )
215
+ return where_method (data , valid_count != 0 )
216
+
217
+
218
+ def _nan_argminmax_object (func , get_fill_value , value , axis = None , ** kwargs ):
219
+ """ In house nanargmin, nanargmax for object arrays. Always return integer
220
+ type """
221
+ fill_value = get_fill_value (value .dtype )
222
+ valid_count = count (value , axis = axis )
223
+ value = fillna (value , fill_value )
224
+ data = _dask_or_eager_func (func )(value , axis = axis , ** kwargs )
225
+ # dask seems return non-integer type
226
+ if isinstance (value , dask_array_type ):
227
+ data = data .astype (int )
228
+
229
+ if (valid_count == 0 ).any ():
230
+ raise ValueError ('All-NaN slice encountered' )
231
+
232
+ return np .array (data , dtype = int )
233
+
234
+
235
+ def _nanmean_ddof_object (ddof , value , axis = None , ** kwargs ):
236
+ """ In house nanmean. ddof argument will be used in _nanvar method """
237
+ valid_count = count (value , axis = axis )
238
+ value = fillna (value , 0 )
239
+ # As dtype inference is impossible for object dtype, we assume float
240
+ # https://github.com/dask/dask/issues/3162
241
+ dtype = kwargs .pop ('dtype' , None )
242
+ if dtype is None and value .dtype .kind == 'O' :
243
+ dtype = value .dtype if value .dtype .kind in ['cf' ] else float
244
+
245
+ data = _dask_or_eager_func ('sum' )(value , axis = axis , dtype = dtype , ** kwargs )
246
+ data = data / (valid_count - ddof )
247
+ return where_method (data , valid_count != 0 )
248
+
249
+
250
+ def _nanvar_object (value , axis = None , ** kwargs ):
251
+ ddof = kwargs .pop ('ddof' , 0 )
252
+ kwargs_mean = kwargs .copy ()
253
+ kwargs_mean .pop ('keepdims' , None )
254
+ value_mean = _nanmean_ddof_object (ddof = 0 , value = value , axis = axis ,
255
+ keepdims = True , ** kwargs_mean )
256
+ squared = (value .astype (value_mean .dtype ) - value_mean )** 2
257
+ return _nanmean_ddof_object (ddof , squared , axis = axis , ** kwargs )
258
+
259
+
260
+ _nan_object_funcs = {
261
+ 'sum' : _nansum_object ,
262
+ 'min' : partial (_nan_minmax_object , 'min' , dtypes .get_pos_infinity ),
263
+ 'max' : partial (_nan_minmax_object , 'max' , dtypes .get_neg_infinity ),
264
+ 'argmin' : partial (_nan_argminmax_object , 'argmin' ,
265
+ dtypes .get_pos_infinity ),
266
+ 'argmax' : partial (_nan_argminmax_object , 'argmax' ,
267
+ dtypes .get_neg_infinity ),
268
+ 'mean' : partial (_nanmean_ddof_object , 0 ),
269
+ 'var' : _nanvar_object ,
270
+ }
271
+
272
+
200
273
def _create_nan_agg_method (name , numeric_only = False , np_compat = False ,
201
274
no_bottleneck = False , coerce_strings = False ,
202
275
keep_dims = False ):
@@ -211,27 +284,31 @@ def f(values, axis=None, skipna=None, **kwargs):
211
284
if coerce_strings and values .dtype .kind in 'SU' :
212
285
values = values .astype (object )
213
286
214
- if skipna or (skipna is None and values .dtype .kind in 'cf ' ):
287
+ if skipna or (skipna is None and values .dtype .kind in 'cfO ' ):
215
288
if values .dtype .kind not in ['u' , 'i' , 'f' , 'c' ]:
216
- raise NotImplementedError (
217
- 'skipna=True not yet implemented for %s with dtype %s'
218
- % (name , values .dtype ))
219
- nanname = 'nan' + name
220
- if (isinstance (axis , tuple ) or not values .dtype .isnative or
221
- no_bottleneck or
222
- (dtype is not None and np .dtype (dtype ) != values .dtype )):
223
- # bottleneck can't handle multiple axis arguments or non-native
224
- # endianness
225
- if np_compat :
226
- eager_module = npcompat
227
- else :
228
- eager_module = np
289
+ func = _nan_object_funcs .get (name , None )
290
+ using_numpy_nan_func = True
291
+ if func is None or values .dtype .kind not in 'Ob' :
292
+ raise NotImplementedError (
293
+ 'skipna=True not yet implemented for %s with dtype %s'
294
+ % (name , values .dtype ))
229
295
else :
230
- kwargs .pop ('dtype' , None )
231
- eager_module = bn
232
- func = _dask_or_eager_func (nanname , eager_module )
233
- using_numpy_nan_func = (eager_module is np or
234
- eager_module is npcompat )
296
+ nanname = 'nan' + name
297
+ if (isinstance (axis , tuple ) or not values .dtype .isnative or
298
+ no_bottleneck or (dtype is not None and
299
+ np .dtype (dtype ) != values .dtype )):
300
+ # bottleneck can't handle multiple axis arguments or
301
+ # non-native endianness
302
+ if np_compat :
303
+ eager_module = npcompat
304
+ else :
305
+ eager_module = np
306
+ else :
307
+ kwargs .pop ('dtype' , None )
308
+ eager_module = bn
309
+ func = _dask_or_eager_func (nanname , eager_module )
310
+ using_numpy_nan_func = (eager_module is np or
311
+ eager_module is npcompat )
235
312
else :
236
313
func = _dask_or_eager_func (name )
237
314
using_numpy_nan_func = False
@@ -240,7 +317,11 @@ def f(values, axis=None, skipna=None, **kwargs):
240
317
return func (values , axis = axis , ** kwargs )
241
318
except AttributeError :
242
319
if isinstance (values , dask_array_type ):
243
- msg = '%s is not yet implemented on dask arrays' % name
320
+ try : # dask/dask#3133 dask sometimes needs dtype argument
321
+ return func (values , axis = axis , dtype = values .dtype ,
322
+ ** kwargs )
323
+ except AttributeError :
324
+ msg = '%s is not yet implemented on dask arrays' % name
244
325
else :
245
326
assert using_numpy_nan_func
246
327
msg = ('%s is not available with skipna=False with the '
0 commit comments