Skip to content

Commit b6a0d60

Browse files
authored
Support nan-ops for object-typed arrays (#1883)
* First support of sum, min, max for object-typed arrays * typo * flake8 * Pandas compatiblity test. Added nanmean for object-type array * Improve test * Support nanvar, nanstd * Fix bug in _create_nan_agg_method * Added nanargmin/nanargmax * Support numpy<1.13. * Update tests. * Some cleanups and whatsnew * Simplify tests. Drop support std. * flake8 * xray -> xr * string array support * Support str dtype. Refactor nanmean * added get_pos_inifinity and get_neg_inifinity * Use function for get_fill_value instead of str. Add test to make sure it raises ValueError in argmin/argmax. * Tests for dtypes.INF
1 parent 2aa5b8a commit b6a0d60

File tree

5 files changed

+384
-25
lines changed

5 files changed

+384
-25
lines changed

doc/whats-new.rst

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@ Documentation
4343

4444
Enhancements
4545
~~~~~~~~~~~~
46-
- reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype``
46+
- Reduce methods such as :py:func:`DataArray.sum()` now handles object-type array.
47+
48+
.. ipython:: python
49+
50+
da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims='x')
51+
da.sum()
52+
53+
(:issue:`1866`)
54+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
55+
- Reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype``
4756
arguments. (:issue:`1838`)
4857
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
4958
- Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`).

xarray/core/dtypes.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import functools
23

34
from . import utils
45

@@ -7,6 +8,29 @@
78
NA = utils.ReprObject('<NA>')
89

910

11+
@functools.total_ordering
12+
class AlwaysGreaterThan(object):
13+
def __gt__(self, other):
14+
return True
15+
16+
def __eq__(self, other):
17+
return isinstance(other, type(self))
18+
19+
20+
@functools.total_ordering
21+
class AlwaysLessThan(object):
22+
def __lt__(self, other):
23+
return True
24+
25+
def __eq__(self, other):
26+
return isinstance(other, type(self))
27+
28+
29+
# Equivalence to np.inf (-np.inf) for object-type
30+
INF = AlwaysGreaterThan()
31+
NINF = AlwaysLessThan()
32+
33+
1034
# Pairs of types that, if both found, should be promoted to object dtype
1135
# instead of following NumPy's own type-promotion rules. These type promotion
1236
# rules match pandas instead. For reference, see the NumPy type hierarchy:
@@ -18,6 +42,29 @@
1842
]
1943

2044

45+
@functools.total_ordering
46+
class AlwaysGreaterThan(object):
47+
def __gt__(self, other):
48+
return True
49+
50+
def __eq__(self, other):
51+
return isinstance(other, type(self))
52+
53+
54+
@functools.total_ordering
55+
class AlwaysLessThan(object):
56+
def __lt__(self, other):
57+
return True
58+
59+
def __eq__(self, other):
60+
return isinstance(other, type(self))
61+
62+
63+
# Equivalence to np.inf (-np.inf) for object-type
64+
INF = AlwaysGreaterThan()
65+
NINF = AlwaysLessThan()
66+
67+
2168
def maybe_promote(dtype):
2269
"""Simpler equivalent of pandas.core.common._maybe_promote
2370
@@ -66,6 +113,46 @@ def get_fill_value(dtype):
66113
return fill_value
67114

68115

116+
def get_pos_infinity(dtype):
117+
"""Return an appropriate positive infinity for this dtype.
118+
119+
Parameters
120+
----------
121+
dtype : np.dtype
122+
123+
Returns
124+
-------
125+
fill_value : positive infinity value corresponding to this dtype.
126+
"""
127+
if issubclass(dtype.type, (np.floating, np.integer)):
128+
return np.inf
129+
130+
if issubclass(dtype.type, np.complexfloating):
131+
return np.inf + 1j * np.inf
132+
133+
return INF
134+
135+
136+
def get_neg_infinity(dtype):
137+
"""Return an appropriate positive infinity for this dtype.
138+
139+
Parameters
140+
----------
141+
dtype : np.dtype
142+
143+
Returns
144+
-------
145+
fill_value : positive infinity value corresponding to this dtype.
146+
"""
147+
if issubclass(dtype.type, (np.floating, np.integer)):
148+
return -np.inf
149+
150+
if issubclass(dtype.type, np.complexfloating):
151+
return -np.inf - 1j * np.inf
152+
153+
return NINF
154+
155+
69156
def is_datetime_like(dtype):
70157
"""Check if a dtype is a subclass of the numpy datetime types
71158
"""

xarray/core/duck_array_ops.py

Lines changed: 101 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,79 @@ def _ignore_warnings_if(condition):
197197
yield
198198

199199

200+
def _nansum_object(value, axis=None, **kwargs):
201+
""" In house nansum for object array """
202+
value = fillna(value, 0)
203+
return _dask_or_eager_func('sum')(value, axis=axis, **kwargs)
204+
205+
206+
def _nan_minmax_object(func, get_fill_value, value, axis=None, **kwargs):
207+
""" In house nanmin and nanmax for object array """
208+
fill_value = get_fill_value(value.dtype)
209+
valid_count = count(value, axis=axis)
210+
filled_value = fillna(value, fill_value)
211+
data = _dask_or_eager_func(func)(filled_value, axis=axis, **kwargs)
212+
if not hasattr(data, 'dtype'): # scalar case
213+
data = dtypes.fill_value(value.dtype) if valid_count == 0 else data
214+
return np.array(data, dtype=value.dtype)
215+
return where_method(data, valid_count != 0)
216+
217+
218+
def _nan_argminmax_object(func, get_fill_value, value, axis=None, **kwargs):
219+
""" In house nanargmin, nanargmax for object arrays. Always return integer
220+
type """
221+
fill_value = get_fill_value(value.dtype)
222+
valid_count = count(value, axis=axis)
223+
value = fillna(value, fill_value)
224+
data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
225+
# dask seems return non-integer type
226+
if isinstance(value, dask_array_type):
227+
data = data.astype(int)
228+
229+
if (valid_count == 0).any():
230+
raise ValueError('All-NaN slice encountered')
231+
232+
return np.array(data, dtype=int)
233+
234+
235+
def _nanmean_ddof_object(ddof, value, axis=None, **kwargs):
236+
""" In house nanmean. ddof argument will be used in _nanvar method """
237+
valid_count = count(value, axis=axis)
238+
value = fillna(value, 0)
239+
# As dtype inference is impossible for object dtype, we assume float
240+
# https://github.com/dask/dask/issues/3162
241+
dtype = kwargs.pop('dtype', None)
242+
if dtype is None and value.dtype.kind == 'O':
243+
dtype = value.dtype if value.dtype.kind in ['cf'] else float
244+
245+
data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs)
246+
data = data / (valid_count - ddof)
247+
return where_method(data, valid_count != 0)
248+
249+
250+
def _nanvar_object(value, axis=None, **kwargs):
251+
ddof = kwargs.pop('ddof', 0)
252+
kwargs_mean = kwargs.copy()
253+
kwargs_mean.pop('keepdims', None)
254+
value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis,
255+
keepdims=True, **kwargs_mean)
256+
squared = (value.astype(value_mean.dtype) - value_mean)**2
257+
return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs)
258+
259+
260+
_nan_object_funcs = {
261+
'sum': _nansum_object,
262+
'min': partial(_nan_minmax_object, 'min', dtypes.get_pos_infinity),
263+
'max': partial(_nan_minmax_object, 'max', dtypes.get_neg_infinity),
264+
'argmin': partial(_nan_argminmax_object, 'argmin',
265+
dtypes.get_pos_infinity),
266+
'argmax': partial(_nan_argminmax_object, 'argmax',
267+
dtypes.get_neg_infinity),
268+
'mean': partial(_nanmean_ddof_object, 0),
269+
'var': _nanvar_object,
270+
}
271+
272+
200273
def _create_nan_agg_method(name, numeric_only=False, np_compat=False,
201274
no_bottleneck=False, coerce_strings=False,
202275
keep_dims=False):
@@ -211,27 +284,31 @@ def f(values, axis=None, skipna=None, **kwargs):
211284
if coerce_strings and values.dtype.kind in 'SU':
212285
values = values.astype(object)
213286

214-
if skipna or (skipna is None and values.dtype.kind in 'cf'):
287+
if skipna or (skipna is None and values.dtype.kind in 'cfO'):
215288
if values.dtype.kind not in ['u', 'i', 'f', 'c']:
216-
raise NotImplementedError(
217-
'skipna=True not yet implemented for %s with dtype %s'
218-
% (name, values.dtype))
219-
nanname = 'nan' + name
220-
if (isinstance(axis, tuple) or not values.dtype.isnative or
221-
no_bottleneck or
222-
(dtype is not None and np.dtype(dtype) != values.dtype)):
223-
# bottleneck can't handle multiple axis arguments or non-native
224-
# endianness
225-
if np_compat:
226-
eager_module = npcompat
227-
else:
228-
eager_module = np
289+
func = _nan_object_funcs.get(name, None)
290+
using_numpy_nan_func = True
291+
if func is None or values.dtype.kind not in 'Ob':
292+
raise NotImplementedError(
293+
'skipna=True not yet implemented for %s with dtype %s'
294+
% (name, values.dtype))
229295
else:
230-
kwargs.pop('dtype', None)
231-
eager_module = bn
232-
func = _dask_or_eager_func(nanname, eager_module)
233-
using_numpy_nan_func = (eager_module is np or
234-
eager_module is npcompat)
296+
nanname = 'nan' + name
297+
if (isinstance(axis, tuple) or not values.dtype.isnative or
298+
no_bottleneck or (dtype is not None and
299+
np.dtype(dtype) != values.dtype)):
300+
# bottleneck can't handle multiple axis arguments or
301+
# non-native endianness
302+
if np_compat:
303+
eager_module = npcompat
304+
else:
305+
eager_module = np
306+
else:
307+
kwargs.pop('dtype', None)
308+
eager_module = bn
309+
func = _dask_or_eager_func(nanname, eager_module)
310+
using_numpy_nan_func = (eager_module is np or
311+
eager_module is npcompat)
235312
else:
236313
func = _dask_or_eager_func(name)
237314
using_numpy_nan_func = False
@@ -240,7 +317,11 @@ def f(values, axis=None, skipna=None, **kwargs):
240317
return func(values, axis=axis, **kwargs)
241318
except AttributeError:
242319
if isinstance(values, dask_array_type):
243-
msg = '%s is not yet implemented on dask arrays' % name
320+
try: # dask/dask#3133 dask sometimes needs dtype argument
321+
return func(values, axis=axis, dtype=values.dtype,
322+
**kwargs)
323+
except AttributeError:
324+
msg = '%s is not yet implemented on dask arrays' % name
244325
else:
245326
assert using_numpy_nan_func
246327
msg = ('%s is not available with skipna=False with the '

xarray/tests/test_dtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ def error():
4646
# would get promoted to float32
4747
actual = dtypes.result_type(array, np.array([0.5, 1.0], dtype=np.float32))
4848
assert actual == np.float64
49+
50+
51+
@pytest.mark.parametrize('obj', [1.0, np.inf, 'ab', 1.0 + 1.0j, True])
52+
def test_inf(obj):
53+
assert dtypes.INF > obj
54+
assert dtypes.NINF < obj

0 commit comments

Comments
 (0)