Skip to content

Commit 0e21fdf

Browse files
authored
Validate output dimension sizes with apply_ufunc (#2155)
* Validate output dimension sizes with apply_ufunc.
  Fixes GH1931.
  Uses of apply_ufunc that change dimension size now raise an explicit error, e.g.,
  >>> xr.apply_ufunc(lambda x: x[:5], xr.Variable('x', np.arange(10)))
  ValueError: size of dimension 'x' on inputs was unexpectedly changed by applied function from 10 to 5. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size.
* lint
* More output validation for apply_ufunc
1 parent 7036eb5 commit 0e21fdf

File tree

3 files changed

+160
-22
lines changed

3 files changed

+160
-22
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ Enhancements
6767
Bug fixes
6868
~~~~~~~~~
6969

70+
- :py:func:`apply_ufunc` now directly validates output variables
71+
(:issue:`1931`).
72+
By `Stephan Hoyer <https://github.com/shoyer>`_.
73+
7074
- Fixed a bug where ``to_netcdf(..., unlimited_dims='bar')`` yielded NetCDF
7175
files with spurious 0-length dimensions (i.e. ``b``, ``a``, and ``r``)
7276
(:issue:`2134`).

xarray/core/computation.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
513513
def apply_variable_ufunc(func, *args, **kwargs):
514514
"""apply_variable_ufunc(func, *args, signature, exclude_dims=frozenset())
515515
"""
516-
from .variable import Variable
516+
from .variable import Variable, as_compatible_data
517517

518518
signature = kwargs.pop('signature')
519519
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
@@ -559,20 +559,42 @@ def func(*arrays):
559559
'apply_ufunc: {}'.format(dask))
560560
result_data = func(*input_data)
561561

562-
if signature.num_outputs > 1:
563-
output = []
564-
for dims, data in zip(output_dims, result_data):
565-
var = Variable(dims, data)
566-
if keep_attrs and isinstance(args[0], Variable):
567-
var.attrs.update(args[0].attrs)
568-
output.append(var)
569-
return tuple(output)
570-
else:
571-
dims, = output_dims
572-
var = Variable(dims, result_data)
562+
if signature.num_outputs == 1:
563+
result_data = (result_data,)
564+
elif (not isinstance(result_data, tuple) or
565+
len(result_data) != signature.num_outputs):
566+
raise ValueError('applied function does not have the number of '
567+
'outputs specified in the ufunc signature. '
568+
'Result is not a tuple of {} elements: {!r}'
569+
.format(signature.num_outputs, result_data))
570+
571+
output = []
572+
for dims, data in zip(output_dims, result_data):
573+
data = as_compatible_data(data)
574+
if data.ndim != len(dims):
575+
raise ValueError(
576+
'applied function returned data with unexpected '
577+
'number of dimensions: {} vs {}, for dimensions {}'
578+
.format(data.ndim, len(dims), dims))
579+
580+
var = Variable(dims, data, fastpath=True)
581+
for dim, new_size in var.sizes.items():
582+
if dim in dim_sizes and new_size != dim_sizes[dim]:
583+
raise ValueError(
584+
'size of dimension {!r} on inputs was unexpectedly '
585+
'changed by applied function from {} to {}. Only '
586+
'dimensions specified in ``exclude_dims`` with '
587+
'xarray.apply_ufunc are allowed to change size.'
588+
.format(dim, dim_sizes[dim], new_size))
589+
573590
if keep_attrs and isinstance(args[0], Variable):
574591
var.attrs.update(args[0].attrs)
575-
return var
592+
output.append(var)
593+
594+
if signature.num_outputs == 1:
595+
return output[0]
596+
else:
597+
return tuple(output)
576598

577599

578600
def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
@@ -719,7 +741,8 @@ def apply_ufunc(func, *args, **kwargs):
719741
Core dimensions on the inputs to exclude from alignment and
720742
broadcasting entirely. Any input coordinates along these dimensions
721743
will be dropped. Each excluded dimension must also appear in
722-
``input_core_dims`` for at least one argument.
744+
``input_core_dims`` for at least one argument. Only dimensions listed
745+
here are allowed to change size between input and output objects.
723746
vectorize : bool, optional
724747
If True, then assume ``func`` only takes arrays defined over core
725748
dimensions as input and vectorize it automatically with
@@ -777,15 +800,38 @@ def apply_ufunc(func, *args, **kwargs):
777800
778801
Examples
779802
--------
780-
For illustrative purposes only, here are examples of how you could use
781-
``apply_ufunc`` to write functions to (very nearly) replicate existing
782-
xarray functionality:
783803
784-
Calculate the vector magnitude of two arguments::
804+
Calculate the vector magnitude of two arguments:
805+
806+
>>> def magnitude(a, b):
807+
... func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
808+
... return xr.apply_ufunc(func, a, b)
809+
810+
You can now apply ``magnitude()`` to ``xr.DataArray`` and ``xr.Dataset``
811+
objects, with automatically preserved dimensions and coordinates, e.g.,
812+
813+
>>> array = xr.DataArray([1, 2, 3], coords=[('x', [0.1, 0.2, 0.3])])
814+
>>> magnitude(array, -array)
815+
<xarray.DataArray (x: 3)>
816+
array([1.414214, 2.828427, 4.242641])
817+
Coordinates:
818+
* x (x) float64 0.1 0.2 0.3
819+
820+
Plain scalars, numpy arrays and a mix of these with xarray objects are also
821+
supported:
822+
823+
>>> magnitude(4, 5)
824+
5.0
825+
>>> magnitude(3, np.array([0, 4]))
826+
array([3., 5.])
827+
>>> magnitude(array, 0)
828+
<xarray.DataArray (x: 3)>
829+
array([1., 2., 3.])
830+
Coordinates:
831+
* x (x) float64 0.1 0.2 0.3
785832
786-
def magnitude(a, b):
787-
func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
788-
return xr.apply_func(func, a, b)
833+
Other examples of how you could use ``apply_ufunc`` to write functions to
834+
(very nearly) replicate existing xarray functionality:
789835
790836
Compute the mean (``.mean``) over one dimension::
791837
@@ -795,7 +841,7 @@ def mean(obj, dim):
795841
input_core_dims=[[dim]],
796842
kwargs={'axis': -1})
797843
798-
Inner product over a specific dimension::
844+
Inner product over a specific dimension (like ``xr.dot``)::
799845
800846
def _inner(x, y):
801847
result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis])

xarray/tests/test_computation.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,94 @@ def test_vectorize_dask():
752752
assert_identical(expected, actual)
753753

754754

755+
def test_output_wrong_number():
756+
variable = xr.Variable('x', np.arange(10))
757+
758+
def identity(x):
759+
return x
760+
761+
def tuple3x(x):
762+
return (x, x, x)
763+
764+
with raises_regex(ValueError, 'number of outputs'):
765+
apply_ufunc(identity, variable, output_core_dims=[(), ()])
766+
767+
with raises_regex(ValueError, 'number of outputs'):
768+
apply_ufunc(tuple3x, variable, output_core_dims=[(), ()])
769+
770+
771+
def test_output_wrong_dims():
772+
variable = xr.Variable('x', np.arange(10))
773+
774+
def add_dim(x):
775+
return x[..., np.newaxis]
776+
777+
def remove_dim(x):
778+
return x[..., 0]
779+
780+
with raises_regex(ValueError, 'unexpected number of dimensions'):
781+
apply_ufunc(add_dim, variable, output_core_dims=[('y', 'z')])
782+
783+
with raises_regex(ValueError, 'unexpected number of dimensions'):
784+
apply_ufunc(add_dim, variable)
785+
786+
with raises_regex(ValueError, 'unexpected number of dimensions'):
787+
apply_ufunc(remove_dim, variable)
788+
789+
790+
def test_output_wrong_dim_size():
791+
array = np.arange(10)
792+
variable = xr.Variable('x', array)
793+
data_array = xr.DataArray(variable, [('x', -array)])
794+
dataset = xr.Dataset({'y': variable}, {'x': -array})
795+
796+
def truncate(array):
797+
return array[:5]
798+
799+
def apply_truncate_broadcast_invalid(obj):
800+
return apply_ufunc(truncate, obj)
801+
802+
with raises_regex(ValueError, 'size of dimension'):
803+
apply_truncate_broadcast_invalid(variable)
804+
with raises_regex(ValueError, 'size of dimension'):
805+
apply_truncate_broadcast_invalid(data_array)
806+
with raises_regex(ValueError, 'size of dimension'):
807+
apply_truncate_broadcast_invalid(dataset)
808+
809+
def apply_truncate_x_x_invalid(obj):
810+
return apply_ufunc(truncate, obj, input_core_dims=[['x']],
811+
output_core_dims=[['x']])
812+
813+
with raises_regex(ValueError, 'size of dimension'):
814+
apply_truncate_x_x_invalid(variable)
815+
with raises_regex(ValueError, 'size of dimension'):
816+
apply_truncate_x_x_invalid(data_array)
817+
with raises_regex(ValueError, 'size of dimension'):
818+
apply_truncate_x_x_invalid(dataset)
819+
820+
def apply_truncate_x_z(obj):
821+
return apply_ufunc(truncate, obj, input_core_dims=[['x']],
822+
output_core_dims=[['z']])
823+
824+
assert_identical(xr.Variable('z', array[:5]),
825+
apply_truncate_x_z(variable))
826+
assert_identical(xr.DataArray(array[:5], dims=['z']),
827+
apply_truncate_x_z(data_array))
828+
assert_identical(xr.Dataset({'y': ('z', array[:5])}),
829+
apply_truncate_x_z(dataset))
830+
831+
def apply_truncate_x_x_valid(obj):
832+
return apply_ufunc(truncate, obj, input_core_dims=[['x']],
833+
output_core_dims=[['x']], exclude_dims={'x'})
834+
835+
assert_identical(xr.Variable('x', array[:5]),
836+
apply_truncate_x_x_valid(variable))
837+
assert_identical(xr.DataArray(array[:5], dims=['x']),
838+
apply_truncate_x_x_valid(data_array))
839+
assert_identical(xr.Dataset({'y': ('x', array[:5])}),
840+
apply_truncate_x_x_valid(dataset))
841+
842+
755843
@pytest.mark.parametrize('use_dask', [True, False])
756844
def test_dot(use_dask):
757845
if use_dask:

0 commit comments

Comments
 (0)