@@ -513,7 +513,7 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
513
513
def apply_variable_ufunc (func , * args , ** kwargs ):
514
514
"""apply_variable_ufunc(func, *args, signature, exclude_dims=frozenset())
515
515
"""
516
- from .variable import Variable
516
+ from .variable import Variable , as_compatible_data
517
517
518
518
signature = kwargs .pop ('signature' )
519
519
exclude_dims = kwargs .pop ('exclude_dims' , _DEFAULT_FROZEN_SET )
@@ -559,20 +559,42 @@ def func(*arrays):
559
559
'apply_ufunc: {}' .format (dask ))
560
560
result_data = func (* input_data )
561
561
562
- if signature .num_outputs > 1 :
563
- output = []
564
- for dims , data in zip (output_dims , result_data ):
565
- var = Variable (dims , data )
566
- if keep_attrs and isinstance (args [0 ], Variable ):
567
- var .attrs .update (args [0 ].attrs )
568
- output .append (var )
569
- return tuple (output )
570
- else :
571
- dims , = output_dims
572
- var = Variable (dims , result_data )
562
+ if signature .num_outputs == 1 :
563
+ result_data = (result_data ,)
564
+ elif (not isinstance (result_data , tuple ) or
565
+ len (result_data ) != signature .num_outputs ):
566
+ raise ValueError ('applied function does not have the number of '
567
+ 'outputs specified in the ufunc signature. '
568
+ 'Result is not a tuple of {} elements: {!r}'
569
+ .format (signature .num_outputs , result_data ))
570
+
571
+ output = []
572
+ for dims , data in zip (output_dims , result_data ):
573
+ data = as_compatible_data (data )
574
+ if data .ndim != len (dims ):
575
+ raise ValueError (
576
+ 'applied function returned data with unexpected '
577
+ 'number of dimensions: {} vs {}, for dimensions {}'
578
+ .format (data .ndim , len (dims ), dims ))
579
+
580
+ var = Variable (dims , data , fastpath = True )
581
+ for dim , new_size in var .sizes .items ():
582
+ if dim in dim_sizes and new_size != dim_sizes [dim ]:
583
+ raise ValueError (
584
+ 'size of dimension {!r} on inputs was unexpectedly '
585
+ 'changed by applied function from {} to {}. Only '
586
+ 'dimensions specified in ``exclude_dims`` with '
587
+ 'xarray.apply_ufunc are allowed to change size.'
588
+ .format (dim , dim_sizes [dim ], new_size ))
589
+
573
590
if keep_attrs and isinstance (args [0 ], Variable ):
574
591
var .attrs .update (args [0 ].attrs )
575
- return var
592
+ output .append (var )
593
+
594
+ if signature .num_outputs == 1 :
595
+ return output [0 ]
596
+ else :
597
+ return tuple (output )
576
598
577
599
578
600
def _apply_with_dask_atop (func , args , input_dims , output_dims , signature ,
@@ -719,7 +741,8 @@ def apply_ufunc(func, *args, **kwargs):
719
741
Core dimensions on the inputs to exclude from alignment and
720
742
broadcasting entirely. Any input coordinates along these dimensions
721
743
will be dropped. Each excluded dimension must also appear in
722
- ``input_core_dims`` for at least one argument.
744
+ ``input_core_dims`` for at least one argument. Only dimensions listed
745
+ here are allowed to change size between input and output objects.
723
746
vectorize : bool, optional
724
747
If True, then assume ``func`` only takes arrays defined over core
725
748
dimensions as input and vectorize it automatically with
@@ -777,15 +800,38 @@ def apply_ufunc(func, *args, **kwargs):
777
800
778
801
Examples
779
802
--------
780
- For illustrative purposes only, here are examples of how you could use
781
- ``apply_ufunc`` to write functions to (very nearly) replicate existing
782
- xarray functionality:
783
803
784
- Calculate the vector magnitude of two arguments::
804
+ Calculate the vector magnitude of two arguments:
805
+
806
+ >>> def magnitude(a, b):
807
+ ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
808
+ ... return xr.apply_ufunc(func, a, b)
809
+
810
+ You can now apply ``magnitude()`` to ``xr.DataArray`` and ``xr.Dataset``
811
+ objects, with automatically preserved dimensions and coordinates, e.g.,
812
+
813
+ >>> array = xr.DataArray([1, 2, 3], coords=[('x', [0.1, 0.2, 0.3])])
814
+ >>> magnitude(array, -array)
815
+ <xarray.DataArray (x: 3)>
816
+ array([1.414214, 2.828427, 4.242641])
817
+ Coordinates:
818
+ * x (x) float64 0.1 0.2 0.3
819
+
820
+ Plain scalars, numpy arrays and a mix of these with xarray objects is also
821
+ supported:
822
+
823
+ >>> magnitude(4, 5)
824
+ 5.0
825
+ >>> magnitude(3, np.array([0, 4]))
826
+ array([3., 5.])
827
+ >>> magnitude(array, 0)
828
+ <xarray.DataArray (x: 3)>
829
+ array([1., 2., 3.])
830
+ Coordinates:
831
+ * x (x) float64 0.1 0.2 0.3
785
832
786
- def magnitude(a, b):
787
- func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
788
- return xr.apply_func(func, a, b)
833
+ Other examples of how you could use ``apply_ufunc`` to write functions to
834
+ (very nearly) replicate existing xarray functionality:
789
835
790
836
Compute the mean (``.mean``) over one dimension::
791
837
@@ -795,7 +841,7 @@ def mean(obj, dim):
795
841
input_core_dims=[[dim]],
796
842
kwargs={'axis': -1})
797
843
798
- Inner product over a specific dimension::
844
+ Inner product over a specific dimension (like ``xr.dot``) ::
799
845
800
846
def _inner(x, y):
801
847
result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis])
0 commit comments