
Support nan-ops for object-typed arrays #1883

Merged
merged 22 commits into from Feb 15, 2018
Changes from all commits
11 changes: 10 additions & 1 deletion doc/whats-new.rst
@@ -43,7 +43,16 @@ Documentation

Enhancements
~~~~~~~~~~~~
- reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype``
- Reduce methods such as :py:func:`DataArray.sum()` now handles object-typed arrays.

.. ipython:: python

da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims='x')
da.sum()

(:issue:`1866`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype``
arguments. (:issue:`1838`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`).
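
A minimal illustration of the ``dtype`` entry above (an assumed example, using a small integer array named ``da``; the output dtype follows from the requested accumulator):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims='x')
    # Ask for a wider accumulator dtype than the array's own int32.
    da.sum(dtype='float64')  # -> scalar DataArray with value 6.0 and dtype float64
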
87 changes: 87 additions & 0 deletions xarray/core/dtypes.py
@@ -1,4 +1,5 @@
import numpy as np
import functools

from . import utils

@@ -7,6 +8,29 @@
NA = utils.ReprObject('<NA>')


@functools.total_ordering
class AlwaysGreaterThan(object):
def __gt__(self, other):
return True

def __eq__(self, other):
return isinstance(other, type(self))


@functools.total_ordering
class AlwaysLessThan(object):
def __lt__(self, other):
return True

def __eq__(self, other):
return isinstance(other, type(self))


# Equivalence to np.inf (-np.inf) for object-type
INF = AlwaysGreaterThan()
NINF = AlwaysLessThan()


# Pairs of types that, if both found, should be promoted to object dtype
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
@@ -18,6 +42,29 @@
]


@functools.total_ordering
class AlwaysGreaterThan(object):
def __gt__(self, other):
return True

def __eq__(self, other):
return isinstance(other, type(self))


@functools.total_ordering
class AlwaysLessThan(object):
def __lt__(self, other):
return True

def __eq__(self, other):
return isinstance(other, type(self))


# Equivalence to np.inf (-np.inf) for object-type
INF = AlwaysGreaterThan()
NINF = AlwaysLessThan()


def maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote

@@ -66,6 +113,46 @@ def get_fill_value(dtype):
return fill_value


def get_pos_infinity(dtype):
"""Return an appropriate positive infinity for this dtype.

Parameters
----------
dtype : np.dtype

Returns
-------
fill_value : positive infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, (np.floating, np.integer)):
return np.inf

if issubclass(dtype.type, np.complexfloating):
return np.inf + 1j * np.inf

return INF


def get_neg_infinity(dtype):
"""Return an appropriate positive infinity for this dtype.

Parameters
----------
dtype : np.dtype

Returns
-------
fill_value : negative infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, (np.floating, np.integer)):
return -np.inf

if issubclass(dtype.type, np.complexfloating):
return -np.inf - 1j * np.inf

return NINF


def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types
"""
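
The two sentinel classes above give object-dtype data something that behaves like positive and negative infinity, while ``get_pos_infinity``/``get_neg_infinity`` return real infinities for numeric dtypes and fall back to the sentinels otherwise. A quick check of that behaviour, assuming the internal module layout from this PR:

    import numpy as np
    from xarray.core import dtypes  # internal module, paths as in this PR

    # Numeric dtypes map to real infinities...
    assert dtypes.get_pos_infinity(np.dtype('int64')) == np.inf
    assert dtypes.get_neg_infinity(np.dtype('float32')) == -np.inf

    # ...while object dtype gets sentinels that compare beyond any value.
    inf = dtypes.get_pos_infinity(np.dtype(object))
    ninf = dtypes.get_neg_infinity(np.dtype(object))
    assert inf > 'abc' and inf > 10 ** 100
    assert ninf < 'abc' and ninf < -10 ** 100
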
121 changes: 101 additions & 20 deletions xarray/core/duck_array_ops.py
@@ -197,6 +197,79 @@ def _ignore_warnings_if(condition):
yield


def _nansum_object(value, axis=None, **kwargs):
""" In house nansum for object array """
value = fillna(value, 0)
return _dask_or_eager_func('sum')(value, axis=axis, **kwargs)


def _nan_minmax_object(func, get_fill_value, value, axis=None, **kwargs):
""" In house nanmin and nanmax for object array """
fill_value = get_fill_value(value.dtype)
valid_count = count(value, axis=axis)
filled_value = fillna(value, fill_value)
data = _dask_or_eager_func(func)(filled_value, axis=axis, **kwargs)
if not hasattr(data, 'dtype'): # scalar case
data = dtypes.fill_value(value.dtype) if valid_count == 0 else data
return np.array(data, dtype=value.dtype)
return where_method(data, valid_count != 0)


def _nan_argminmax_object(func, get_fill_value, value, axis=None, **kwargs):
""" In house nanargmin, nanargmax for object arrays. Always return integer
type """
fill_value = get_fill_value(value.dtype)
valid_count = count(value, axis=axis)
value = fillna(value, fill_value)
data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
# dask seems to return a non-integer type
if isinstance(value, dask_array_type):
data = data.astype(int)

if (valid_count == 0).any():
raise ValueError('All-NaN slice encountered')

return np.array(data, dtype=int)


def _nanmean_ddof_object(ddof, value, axis=None, **kwargs):
""" In house nanmean. ddof argument will be used in _nanvar method """
valid_count = count(value, axis=axis)
value = fillna(value, 0)
# As dtype inference is impossible for object dtype, we assume float
# https://github.com/dask/dask/issues/3162
dtype = kwargs.pop('dtype', None)
if dtype is None and value.dtype.kind == 'O':
dtype = value.dtype if value.dtype.kind in ['cf'] else float

data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs)
data = data / (valid_count - ddof)
return where_method(data, valid_count != 0)


def _nanvar_object(value, axis=None, **kwargs):
ddof = kwargs.pop('ddof', 0)
kwargs_mean = kwargs.copy()
kwargs_mean.pop('keepdims', None)
value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis,
keepdims=True, **kwargs_mean)
squared = (value.astype(value_mean.dtype) - value_mean)**2
return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs)


_nan_object_funcs = {
'sum': _nansum_object,
'min': partial(_nan_minmax_object, 'min', dtypes.get_pos_infinity),
'max': partial(_nan_minmax_object, 'max', dtypes.get_neg_infinity),
'argmin': partial(_nan_argminmax_object, 'argmin',
dtypes.get_pos_infinity),
'argmax': partial(_nan_argminmax_object, 'argmax',
dtypes.get_neg_infinity),
'mean': partial(_nanmean_ddof_object, 0),
'var': _nanvar_object,
}


def _create_nan_agg_method(name, numeric_only=False, np_compat=False,
no_bottleneck=False, coerce_strings=False,
keep_dims=False):
@@ -211,27 +284,31 @@ def f(values, axis=None, skipna=None, **kwargs):
if coerce_strings and values.dtype.kind in 'SU':
values = values.astype(object)

if skipna or (skipna is None and values.dtype.kind in 'cf'):
if skipna or (skipna is None and values.dtype.kind in 'cfO'):
if values.dtype.kind not in ['u', 'i', 'f', 'c']:
raise NotImplementedError(
'skipna=True not yet implemented for %s with dtype %s'
% (name, values.dtype))
nanname = 'nan' + name
if (isinstance(axis, tuple) or not values.dtype.isnative or
no_bottleneck or
(dtype is not None and np.dtype(dtype) != values.dtype)):
# bottleneck can't handle multiple axis arguments or non-native
# endianness
if np_compat:
eager_module = npcompat
else:
eager_module = np
func = _nan_object_funcs.get(name, None)
using_numpy_nan_func = True
if func is None or values.dtype.kind not in 'Ob':
raise NotImplementedError(
'skipna=True not yet implemented for %s with dtype %s'
% (name, values.dtype))
else:
kwargs.pop('dtype', None)
eager_module = bn
func = _dask_or_eager_func(nanname, eager_module)
using_numpy_nan_func = (eager_module is np or
eager_module is npcompat)
nanname = 'nan' + name
if (isinstance(axis, tuple) or not values.dtype.isnative or
no_bottleneck or (dtype is not None and
np.dtype(dtype) != values.dtype)):
# bottleneck can't handle multiple axis arguments or
# non-native endianness
if np_compat:
eager_module = npcompat
else:
eager_module = np
else:
kwargs.pop('dtype', None)
eager_module = bn
func = _dask_or_eager_func(nanname, eager_module)
using_numpy_nan_func = (eager_module is np or
eager_module is npcompat)
else:
func = _dask_or_eager_func(name)
using_numpy_nan_func = False
@@ -240,7 +317,11 @@ def f(values, axis=None, skipna=None, **kwargs):
return func(values, axis=axis, **kwargs)
except AttributeError:
if isinstance(values, dask_array_type):
msg = '%s is not yet implemented on dask arrays' % name
try: # dask/dask#3133 dask sometimes needs dtype argument
return func(values, axis=axis, dtype=values.dtype,
**kwargs)
except AttributeError:
msg = '%s is not yet implemented on dask arrays' % name
else:
assert using_numpy_nan_func
msg = ('%s is not available with skipna=False with the '
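
All of the object-dtype nan-reductions above share one pattern: count the valid entries, replace NaNs with a neutral value or an infinity sentinel, run the ordinary reduction, then mask (or raise on) all-NaN slices. A stripped-down, 1-D NumPy sketch of that pattern for ``min``; the names here are illustrative and are not the PR's helpers:

    import functools

    import numpy as np


    @functools.total_ordering
    class _Sentinel(object):
        """Compares greater than everything; stands in for dtypes.INF."""
        def __gt__(self, other):
            return True

        def __eq__(self, other):
            return isinstance(other, type(self))


    def nanmin_object_1d(values):
        """1-D sketch of the fill-with-a-sentinel approach."""
        is_nan = np.array([isinstance(v, float) and np.isnan(v) for v in values])
        if not (~is_nan).any():
            return np.nan                 # all-NaN slice: nothing valid to reduce
        filled = values.copy()
        filled[is_nan] = _Sentinel()      # NaNs can never win the minimum
        return filled.min()


    nanmin_object_1d(np.array([True, False, np.nan], dtype=object))  # -> False
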
6 changes: 6 additions & 0 deletions xarray/tests/test_dtypes.py
@@ -46,3 +46,9 @@ def error():
# would get promoted to float32
actual = dtypes.result_type(array, np.array([0.5, 1.0], dtype=np.float32))
assert actual == np.float64


@pytest.mark.parametrize('obj', [1.0, np.inf, 'ab', 1.0 + 1.0j, True])
def test_inf(obj):
assert dtypes.INF > obj
assert dtypes.NINF < obj
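
Putting the pieces together, the new skipna path should make the usual reductions work on object arrays like the one in the whats-new example above (expected results noted in the comments):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims='x')
    da.sum()     # NaN is skipped: True + False -> 1
    da.min()     # False
    da.mean()    # 0.5; object input is accumulated as float
    da.argmax()  # 0, the position of True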