Skip to content

Commit 56abba5

Browse files
crusaderky authored and shoyer committed
Align arguments, auto align in broadcast and concat (#963)
* align(): added exclude and index optional parameters.
  broadcast(): added exclude and copy optional parameters.
  Auto-align misaligned attrs.
  This is unfinished work - see TODO
* Recommitted changes after rebaselining
* Added tests for broadcast on DataArray with copy=True|False and exclude= parameters
* Removed copy flag from broadcast (now it's always False). Added/fixed unit tests
* Fixed bug in dataset broadcast. Completed unit tests on align() and broadcast().
* Cleaned up code in broadcast(). Use source_ndarray() in no-copy tests
* concat() to auto-align inputs. Fixed failing unit test on no-copy. Updated What's New.
* tweaked What's New
1 parent 606e1d9 commit 56abba5

File tree

7 files changed

+228
-48
lines changed

7 files changed

+228
-48
lines changed

doc/whats-new.rst

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,12 @@ Enhancements
2828
xarray objects (:issue:`432`). See :ref:`dictionary IO <dictionary io>`
2929
for more details. By `Julia Signell <https://github.com/jsignell>`_.
3030

31+
- Added ``exclude`` and ``indexes`` optional parameters to :py:func:`~xarray.align`,
32+
and ``exclude`` optional parameter to :py:func:`~xarray.broadcast`.
33+
By `Guido Imperiale <https://github.com/crusaderky>`_.
34+
- :py:func:`~xarray.broadcast` and :py:func:`~xarray.concat` will now auto-align inputs,
35+
using ``join='outer'``. By `Guido Imperiale <https://github.com/crusaderky>`_.
36+
3137
Bug fixes
3238
~~~~~~~~~
3339

xarray/core/alignment.py

Lines changed: 41 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from .common import _maybe_promote
1010
from .pycompat import iteritems, OrderedDict
1111
from .utils import is_full_slice, is_dict_like
12-
from .variable import Variable, Coordinate, broadcast_variables
12+
from .variable import Variable, Coordinate
1313

1414

1515
def _get_joiner(join):
@@ -71,25 +71,17 @@ def align(*objects, **kwargs):
7171
If ``copy=True``, the returned objects contain all new variables. If
7272
``copy=False`` and no reindexing is required then the aligned objects
7373
will include original variables.
74+
exclude : sequence of str, optional
75+
Dimensions that must be excluded from alignment
76+
indexes : dict-like, optional
77+
Any indexes explicitly provided with the `indexes` argument should be
78+
used in preference to the aligned indexes.
7479
7580
Returns
7681
-------
7782
aligned : same as *objects
7883
Tuple of objects with aligned coordinates.
7984
"""
80-
return partial_align(*objects, exclude=None, **kwargs)
81-
82-
83-
def partial_align(*objects, **kwargs):
84-
"""partial_align(*objects, join='inner', copy=True, indexes=None,
85-
exclude=set())
86-
87-
Like align, but don't align along dimensions in exclude. Any indexes
88-
explicitly provided with the `indexes` argument should be used in preference
89-
to the aligned indexes.
90-
91-
Not public API.
92-
"""
9385
join = kwargs.pop('join', 'inner')
9486
copy = kwargs.pop('copy', True)
9587
indexes = kwargs.pop('indexes', None)
@@ -247,21 +239,26 @@ def var_indexers(var, indexers):
247239
return reindexed
248240

249241

250-
def broadcast(*args):
242+
def broadcast(*args, **kwargs):
251243
"""Explicitly broadcast any number of DataArray or Dataset objects against
252244
one another.
253245
254246
xarray objects automatically broadcast against each other in arithmetic
255247
operations, so this function should not be necessary for normal use.
256248
249+
If no change is needed, the input data is returned to the output without
250+
being copied.
251+
257252
Parameters
258253
----------
259-
*args: DataArray or Dataset objects
254+
*args : DataArray or Dataset objects
260255
Arrays to broadcast against each other.
256+
exclude : sequence of str, optional
257+
Dimensions that must not be broadcasted
261258
262259
Returns
263260
-------
264-
broadcast: tuple of xarray objects
261+
broadcast : tuple of xarray objects
265262
The same data as the input arrays, but with additional dimensions
266263
inserted so that all data arrays have the same dimensions and shape.
267264
@@ -322,36 +319,48 @@ def broadcast(*args):
322319
from .dataarray import DataArray
323320
from .dataset import Dataset
324321

325-
all_indexes = _get_all_indexes(args)
326-
for k, v in all_indexes.items():
327-
if not all(v[0].equals(vi) for vi in v[1:]):
328-
raise ValueError('cannot broadcast arrays: the %s index is not '
329-
'aligned (use xarray.align first)' % k)
322+
exclude = kwargs.pop('exclude', None)
323+
if exclude is None:
324+
exclude = set()
325+
if kwargs:
326+
raise TypeError('broadcast() got unexpected keyword arguments: %s'
327+
% list(kwargs))
328+
329+
args = align(*args, join='outer', copy=False, exclude=exclude)
330330

331331
common_coords = OrderedDict()
332332
dims_map = OrderedDict()
333333
for arg in args:
334334
for dim in arg.dims:
335-
if dim not in common_coords:
335+
if dim not in common_coords and dim not in exclude:
336336
common_coords[dim] = arg.coords[dim].variable
337337
dims_map[dim] = common_coords[dim].size
338338

339+
def _expand_dims(var):
340+
# Add excluded dims to a copy of dims_map
341+
var_dims_map = dims_map.copy()
342+
for dim in exclude:
343+
try:
344+
var_dims_map[dim] = var.shape[var.dims.index(dim)]
345+
except ValueError:
346+
# dim not in var.dims
347+
pass
348+
349+
return var.expand_dims(var_dims_map)
350+
339351
def _broadcast_array(array):
340-
data = array.variable.expand_dims(dims_map)
352+
data = _expand_dims(array.variable)
341353
coords = OrderedDict(array.coords)
342354
coords.update(common_coords)
343-
dims = tuple(common_coords)
344-
return DataArray(data, coords, dims, name=array.name,
355+
return DataArray(data, coords, data.dims, name=array.name,
345356
attrs=array.attrs, encoding=array.encoding)
346357

347358
def _broadcast_dataset(ds):
348-
data_vars = OrderedDict()
349-
for k in ds.data_vars:
350-
data_vars[k] = ds.variables[k].expand_dims(dims_map)
351-
359+
data_vars = OrderedDict(
360+
(k, _expand_dims(ds.variables[k]))
361+
for k in ds.data_vars)
352362
coords = OrderedDict(ds.coords)
353363
coords.update(common_coords)
354-
355364
return Dataset(data_vars, coords, ds.attrs)
356365

357366
result = []
@@ -367,6 +376,7 @@ def _broadcast_dataset(ds):
367376

368377

369378
def broadcast_arrays(*args):
379+
import warnings
370380
warnings.warn('xarray.broadcast_arrays is deprecated: use '
371381
'xarray.broadcast instead', DeprecationWarning, stacklevel=2)
372382
return broadcast(*args)

xarray/core/combine.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44

55
from . import utils
6+
from .alignment import align
67
from .merge import merge
78
from .pycompat import iteritems, OrderedDict, basestring
89
from .variable import Variable, as_variable, Coordinate, concat as concat_vars
@@ -202,10 +203,9 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
202203
raise ValueError("compat=%r invalid: must be 'equals' "
203204
"or 'identical'" % compat)
204205

205-
# don't bother trying to work with datasets as a generator instead of a
206-
# list; the gains would be minimal
207-
datasets = [as_dataset(ds) for ds in datasets]
208206
dim, coord = _calc_concat_dim_coord(dim)
207+
datasets = [as_dataset(ds) for ds in datasets]
208+
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
209209

210210
concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
211211

xarray/core/merge.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import pandas as pd
22

3-
from .alignment import partial_align
3+
from .alignment import align
44
from .utils import Frozen, is_dict_like
55
from .variable import as_variable, default_index_coordinate
66
from .pycompat import (basestring, OrderedDict)
@@ -322,7 +322,7 @@ def is_alignable(obj):
322322
# https://github.com/pydata/xarray/issues/943
323323
return input_objects
324324

325-
aligned = partial_align(*targets, join=join, copy=copy, indexes=indexes)
325+
aligned = align(*targets, join=join, copy=copy, indexes=indexes)
326326

327327
for position, key, aligned_obj in zip(positions, keys, aligned):
328328
if key is no_key:

xarray/test/test_combine.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,14 @@ def test_concat_size0(self):
106106
actual = concat(split_data[::-1], 'dim1')
107107
self.assertDatasetIdentical(data, actual)
108108

109+
def test_concat_autoalign(self):
110+
ds1 = Dataset({'foo': DataArray([1, 2], coords={'x': [1, 2]})})
111+
ds2 = Dataset({'foo': DataArray([1, 2], coords={'x': [1, 3]})})
112+
actual = concat([ds1, ds2], 'y')
113+
expected = Dataset({'foo': DataArray([[1, 2, np.nan], [1, np.nan, 2]],
114+
dims=['y', 'x'], coords={'y': [0, 1], 'x': [1, 2, 3]})})
115+
self.assertDatasetIdentical(expected, actual)
116+
109117
def test_concat_errors(self):
110118
data = create_test_data()
111119
split_data = [data.isel(dim1=slice(3)),
@@ -129,14 +137,6 @@ def test_concat_errors(self):
129137
data1['foo'] = ('bar', np.random.randn(10))
130138
concat([data0, data1], 'dim1')
131139

132-
with self.assertRaisesRegexp(ValueError, 'not equal across datasets'):
133-
data0, data1 = deepcopy(split_data)
134-
data1['dim2'] = 2 * data1['dim2']
135-
concat([data0, data1], 'dim1', coords='minimal')
136-
137-
with self.assertRaisesRegexp(ValueError, 'it is not 1-dimensional'):
138-
concat([data0, data1], 'dim1')
139-
140140
with self.assertRaisesRegexp(ValueError, 'compat.* invalid'):
141141
concat(split_data, 'dim1', compat='foobar')
142142

xarray/test/test_dataarray.py

Lines changed: 93 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1618,6 +1618,60 @@ def test_align_dtype(self):
16181618
c, d = align(a, b, join='outer')
16191619
self.assertEqual(c.dtype, np.float32)
16201620

1621+
def test_align_copy(self):
1622+
x = DataArray([1, 2, 3], coords=[('a', [1, 2, 3])])
1623+
y = DataArray([1, 2], coords=[('a', [3, 1])])
1624+
1625+
expected_x2 = x
1626+
expected_y2 = DataArray([2, np.nan, 1], coords=[('a', [1, 2, 3])])
1627+
1628+
x2, y2 = align(x, y, join='outer', copy=False)
1629+
self.assertDataArrayIdentical(expected_x2, x2)
1630+
self.assertDataArrayIdentical(expected_y2, y2)
1631+
assert source_ndarray(x2.data) is source_ndarray(x.data)
1632+
1633+
x2, y2 = align(x, y, join='outer', copy=True)
1634+
self.assertDataArrayIdentical(expected_x2, x2)
1635+
self.assertDataArrayIdentical(expected_y2, y2)
1636+
assert source_ndarray(x2.data) is not source_ndarray(x.data)
1637+
1638+
# Trivial align - 1 element
1639+
x = DataArray([1, 2, 3], coords=[('a', [1, 2, 3])])
1640+
x2, = align(x, copy=False)
1641+
self.assertDataArrayIdentical(x, x2)
1642+
assert source_ndarray(x2.data) is source_ndarray(x.data)
1643+
1644+
x2, = align(x, copy=True)
1645+
self.assertDataArrayIdentical(x, x2)
1646+
assert source_ndarray(x2.data) is not source_ndarray(x.data)
1647+
1648+
def test_align_exclude(self):
1649+
x = DataArray([[1, 2], [3, 4]], coords=[('a', [-1, -2]), ('b', [3, 4])])
1650+
y = DataArray([[1, 2], [3, 4]], coords=[('a', [-1, 20]), ('b', [5, 6])])
1651+
z = DataArray([1], dims=['a'], coords={'a': [20], 'b': 7})
1652+
1653+
x2, y2, z2 = align(x, y, z, join='outer', exclude=['b'])
1654+
expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], coords=[('a', [-2, -1, 20]), ('b', [3, 4])])
1655+
expected_y2 = DataArray([[np.nan, np.nan], [1, 2], [3, 4]], coords=[('a', [-2, -1, 20]), ('b', [5, 6])])
1656+
expected_z2 = DataArray([np.nan, np.nan, 1], dims=['a'], coords={'a': [-2, -1, 20], 'b': 7})
1657+
self.assertDataArrayIdentical(expected_x2, x2)
1658+
self.assertDataArrayIdentical(expected_y2, y2)
1659+
self.assertDataArrayIdentical(expected_z2, z2)
1660+
1661+
def test_align_indexes(self):
1662+
x = DataArray([1, 2, 3], coords=[('a', [-1, 10, -2])])
1663+
y = DataArray([1, 2], coords=[('a', [-2, -1])])
1664+
1665+
x2, y2 = align(x, y, join='outer', indexes={'a': [10, -1, -2]})
1666+
expected_x2 = DataArray([2, 1, 3], coords=[('a', [10, -1, -2])])
1667+
expected_y2 = DataArray([np.nan, 2, 1], coords=[('a', [10, -1, -2])])
1668+
self.assertDataArrayIdentical(expected_x2, x2)
1669+
self.assertDataArrayIdentical(expected_y2, y2)
1670+
1671+
x2, = align(x, join='outer', indexes={'a': [-2, 7, 10, -1]})
1672+
expected_x2 = DataArray([3, np.nan, 2, 1], coords=[('a', [-2, 7, 10, -1])])
1673+
self.assertDataArrayIdentical(expected_x2, x2)
1674+
16211675
def test_broadcast_arrays(self):
16221676
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
16231677
y = DataArray([1, 2], coords=[('b', [3, 4])], name='y')
@@ -1636,9 +1690,45 @@ def test_broadcast_arrays(self):
16361690
self.assertDataArrayIdentical(expected_x2, x2)
16371691
self.assertDataArrayIdentical(expected_y2, y2)
16381692

1639-
z = DataArray([1, 2], coords=[('a', [-10, 20])])
1640-
with self.assertRaisesRegexp(ValueError, 'cannot broadcast'):
1641-
broadcast(x, z)
1693+
def test_broadcast_arrays_misaligned(self):
1694+
# broadcast on misaligned coords must auto-align
1695+
x = DataArray([[1, 2], [3, 4]], coords=[('a', [-1, -2]), ('b', [3, 4])])
1696+
y = DataArray([1, 2], coords=[('a', [-1, 20])])
1697+
expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], coords=[('a', [-2, -1, 20]), ('b', [3, 4])])
1698+
expected_y2 = DataArray([[np.nan, np.nan], [1, 1], [2, 2]], coords=[('a', [-2, -1, 20]), ('b', [3, 4])])
1699+
x2, y2 = broadcast(x, y)
1700+
self.assertDataArrayIdentical(expected_x2, x2)
1701+
self.assertDataArrayIdentical(expected_y2, y2)
1702+
1703+
def test_broadcast_arrays_nocopy(self):
1704+
# Test that input data is not copied over in case no alteration is needed
1705+
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
1706+
y = DataArray(3, name='y')
1707+
expected_x2 = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
1708+
expected_y2 = DataArray([3, 3], coords=[('a', [-1, -2])], name='y')
1709+
1710+
x2, y2 = broadcast(x, y)
1711+
self.assertDataArrayIdentical(expected_x2, x2)
1712+
self.assertDataArrayIdentical(expected_y2, y2)
1713+
assert source_ndarray(x2.data) is source_ndarray(x.data)
1714+
1715+
# single-element broadcast (trivial case)
1716+
x2, = broadcast(x)
1717+
self.assertDataArrayIdentical(x, x2)
1718+
assert source_ndarray(x2.data) is source_ndarray(x.data)
1719+
1720+
def test_broadcast_arrays_exclude(self):
1721+
x = DataArray([[1, 2], [3, 4]], coords=[('a', [-1, -2]), ('b', [3, 4])])
1722+
y = DataArray([1, 2], coords=[('a', [-1, 20])])
1723+
z = DataArray(5, coords={'b': 5})
1724+
1725+
x2, y2, z2 = broadcast(x, y, z, exclude=['b'])
1726+
expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], coords=[('a', [-2, -1, 20]), ('b', [3, 4])])
1727+
expected_y2 = DataArray([np.nan, 1, 2], coords=[('a', [-2, -1, 20])])
1728+
expected_z2 = DataArray([5, 5, 5], dims=['a'], coords={'a': [-2, -1, 20], 'b': 5})
1729+
self.assertDataArrayIdentical(expected_x2, x2)
1730+
self.assertDataArrayIdentical(expected_y2, y2)
1731+
self.assertDataArrayIdentical(expected_z2, z2)
16421732

16431733
def test_broadcast_coordinates(self):
16441734
# regression test for GH649

0 commit comments

Comments
 (0)