Commit 14b5f1c

crusaderky authored and Joe Hamman committed
Load nonindex coords ahead of concat() (#1551)
* Load non-index coords to memory ahead of concat
* Update unit test after #1522
* Minimise loads on concat. Extend new concat logic to data_vars.
* Trivial tweaks
* Added unit tests
  Fix loads when vars are found different halfway through
* Add xfail for #1586
* Revert "Add xfail for #1586"
  This reverts commit f99313c.
1 parent 772f7e0 commit 14b5f1c

File tree

3 files changed: +153 -43 lines changed


doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
@@ -189,6 +189,13 @@ Bug fixes
   ``rtol`` arguments when called on ``DataArray`` objects.
   By `Stephan Hoyer <https://github.com/shoyer>`_.
 
+- :py:func:`~xarray.concat` was computing variables that aren't in memory
+  (e.g. dask-based) multiple times; :py:func:`~xarray.open_mfdataset`
+  was loading them multiple times from disk. Now, both functions will instead
+  load them at most once and, if they do, store them in memory in the
+  concatenated array/dataset (:issue:`1521`).
+  By `Guido Imperiale <https://github.com/crusaderky>`_.
+
 - xarray ``quantile`` methods now properly raise a ``TypeError`` when applied to
   objects with data stored as ``dask`` arrays (:issue:`1529`).
   By `Joe Hamman <https://github.com/jhamman>`_.
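
As an illustration of the behaviour described in the new entry (not part of the commit), a minimal sketch assuming dask is installed; the dataset contents are made up:

```python
import dask.array as da
import xarray as xr

# Two datasets whose data variable and non-index coordinate are dask-backed.
ds1 = xr.Dataset({'d': ('x', da.zeros(3, chunks=3))},
                 coords={'c': ('x', da.zeros(3, chunks=3))})
ds2 = xr.Dataset({'d': ('x', da.ones(3, chunks=3))},
                 coords={'c': ('x', da.ones(3, chunks=3))})

# Concatenating along a new dimension with the 'different' heuristics forces
# concat() to compare 'd' and 'c' across datasets. After this change, each
# lazy variable is computed at most once and the computed values are kept in
# the result instead of being discarded and recomputed later.
out = xr.concat([ds1, ds2], dim='n', data_vars='different', coords='different')
print(type(out['d'].data))  # numpy.ndarray: loaded once, stored in the output
```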

xarray/core/combine.py

Lines changed: 72 additions & 41 deletions
@@ -148,68 +148,85 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
     Determine which dataset variables need to be concatenated in the result,
     and which can simply be taken from the first dataset.
     """
-    def process_subset_opt(opt, subset):
-        if subset == 'coords':
-            subset_long_name = 'coordinates'
-        else:
-            subset_long_name = 'data variables'
+    # Return values
+    concat_over = set()
+    equals = {}
+
+    if dim in datasets[0]:
+        concat_over.add(dim)
+    for ds in datasets:
+        concat_over.update(k for k, v in ds.variables.items()
+                           if dim in v.dims)
 
+    def process_subset_opt(opt, subset):
         if isinstance(opt, basestring):
             if opt == 'different':
-                def differs(vname):
-                    # simple helper function which compares a variable
-                    # across all datasets and indicates whether that
-                    # variable differs or not.
-                    v = datasets[0].variables[vname]
-                    return any(not ds.variables[vname].equals(v)
-                               for ds in datasets[1:])
                 # all nonindexes that are not the same in each dataset
-                concat_new = set(k for k in getattr(datasets[0], subset)
-                                 if k not in concat_over and differs(k))
+                for k in getattr(datasets[0], subset):
+                    if k not in concat_over:
+                        # Compare the variable of all datasets vs. the one
+                        # of the first dataset. Perform the minimum amount of
+                        # loads in order to avoid multiple loads from disk while
+                        # keeping the RAM footprint low.
+                        v_lhs = datasets[0].variables[k].load()
+                        # We'll need to know later on if variables are equal.
+                        computed = []
+                        for ds_rhs in datasets[1:]:
+                            v_rhs = ds_rhs.variables[k].compute()
+                            computed.append(v_rhs)
+                            if not v_lhs.equals(v_rhs):
+                                concat_over.add(k)
+                                equals[k] = False
+                                # computed variables are not to be re-computed
+                                # again in the future
+                                for ds, v in zip(datasets[1:], computed):
+                                    ds.variables[k].data = v.data
+                                break
+                        else:
+                            equals[k] = True
+
             elif opt == 'all':
-                concat_new = (set(getattr(datasets[0], subset)) -
-                              set(datasets[0].dims))
+                concat_over.update(set(getattr(datasets[0], subset)) -
+                                   set(datasets[0].dims))
             elif opt == 'minimal':
-                concat_new = set()
+                pass
             else:
-                raise ValueError("unexpected value for concat_%s: %s"
-                                 % (subset, opt))
+                raise ValueError("unexpected value for %s: %s" % (subset, opt))
         else:
             invalid_vars = [k for k in opt
                             if k not in getattr(datasets[0], subset)]
             if invalid_vars:
-                raise ValueError('some variables in %s are not '
-                                 '%s on the first dataset: %s'
-                                 % (subset, subset_long_name, invalid_vars))
-            concat_new = set(opt)
-        return concat_new
+                if subset == 'coords':
+                    raise ValueError(
+                        'some variables in coords are not coordinates on '
+                        'the first dataset: %s' % invalid_vars)
+                else:
+                    raise ValueError(
+                        'some variables in data_vars are not data variables on '
+                        'the first dataset: %s' % invalid_vars)
+            concat_over.update(opt)
 
-    concat_over = set()
-    for ds in datasets:
-        concat_over.update(k for k, v in ds.variables.items()
-                           if dim in v.dims)
-    concat_over.update(process_subset_opt(data_vars, 'data_vars'))
-    concat_over.update(process_subset_opt(coords, 'coords'))
-    if dim in datasets[0]:
-        concat_over.add(dim)
-    return concat_over
+    process_subset_opt(data_vars, 'data_vars')
+    process_subset_opt(coords, 'coords')
+    return concat_over, equals
 
 
 def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
-    from .dataset import Dataset, as_dataset
+    from .dataset import Dataset
 
     if compat not in ['equals', 'identical']:
         raise ValueError("compat=%r invalid: must be 'equals' "
                          "or 'identical'" % compat)
 
     dim, coord = _calc_concat_dim_coord(dim)
-    datasets = [as_dataset(ds) for ds in datasets]
+    # Make sure we're working on a copy (we'll be loading variables)
+    datasets = [ds.copy() for ds in datasets]
     datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
 
-    concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
+    concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
 
     def insert_result_variable(k, v):
         assert isinstance(v, Variable)
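
A side note on the new comparison loop above: it relies on Python's ``for ... else``, where the ``else`` branch runs only when the loop finishes without hitting ``break``, i.e. when no dataset differed from the first one. A tiny self-contained illustration of that control flow (the helper name is made up for the example):

```python
def all_equal(first, rest):
    # Break out on the first mismatch; the else clause of the for loop runs
    # only when the loop completes without breaking, i.e. everything matched.
    for item in rest:
        if item != first:
            break
    else:
        return True
    return False

print(all_equal(1, [1, 1, 1]))  # True
print(all_equal(1, [1, 2, 1]))  # False
```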
@@ -239,11 +256,25 @@ def insert_result_variable(k, v):
             elif (k in result_coord_names) != (k in ds.coords):
                 raise ValueError('%r is a coordinate in some datasets but not '
                                  'others' % k)
-            elif (k in result_vars and k != dim and
-                  not getattr(v, compat)(result_vars[k])):
-                verb = 'equal' if compat == 'equals' else compat
-                raise ValueError(
-                    'variable %r not %s across datasets' % (k, verb))
+            elif k in result_vars and k != dim:
+                # Don't use Variable.identical as it internally invokes
+                # Variable.equals, and we may already know the answer
+                if compat == 'identical' and not utils.dict_equiv(
+                        v.attrs, result_vars[k].attrs):
+                    raise ValueError(
+                        'variable %s not identical across datasets' % k)
+
+                # Proceed with equals()
+                try:
+                    # May be populated when using the "different" method
+                    is_equal = equals[k]
+                except KeyError:
+                    result_vars[k].load()
+                    is_equal = v.equals(result_vars[k])
+                if not is_equal:
+                    raise ValueError(
+                        'variable %s not equal across datasets' % k)
+
 
     # we've already verified everything is consistent; now, calculate
     # shared dimension sizes so we can expand the necessary variables
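
The ``try``/``except KeyError`` in this hunk consults the ``equals`` dict returned by ``_calc_concat_over``, so a variable already compared during the ``'different'`` pass is never compared (and therefore never computed) a second time. A minimal sketch of that cache-then-fall-back pattern, using hypothetical names rather than xarray objects:

```python
def check_equal(name, candidate, reference, verdict_cache):
    # Trust a verdict recorded by an earlier pass; only fall back to the
    # (potentially expensive) comparison when no verdict is cached.
    try:
        is_equal = verdict_cache[name]
    except KeyError:
        is_equal = candidate == reference[name]
    if not is_equal:
        raise ValueError('variable %s not equal across datasets' % name)

cache = {'temperature': True}                                    # earlier verdict
check_equal('temperature', 21.5, {'temperature': 21.5}, cache)  # no comparison
check_equal('pressure', 1013, {'pressure': 1013}, cache)        # compared here
```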

xarray/tests/test_dask.py

Lines changed: 74 additions & 2 deletions
@@ -250,6 +250,77 @@ def test_lazy_array(self):
         actual = xr.concat([v[:2], v[2:]], 'x')
         self.assertLazyAndAllClose(u, actual)
 
+    def test_concat_loads_variables(self):
+        # Test that concat() computes not-in-memory variables at most once
+        # and loads them in the output, while leaving the input unaltered.
+        d1 = build_dask_array('d1')
+        c1 = build_dask_array('c1')
+        d2 = build_dask_array('d2')
+        c2 = build_dask_array('c2')
+        d3 = build_dask_array('d3')
+        c3 = build_dask_array('c3')
+        # Note: c is a non-index coord.
+        # Index coords are loaded by IndexVariable.__init__.
+        ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)})
+        ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)})
+        ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)})
+
+        assert kernel_call_count == 0
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # each kernel is computed exactly once
+        assert kernel_call_count == 6
+        # variables are loaded in the output
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all')
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c'])
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[])
+        # variables are loaded once as we are validating that they're identical
+        assert kernel_call_count == 12
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different', compat='identical')
+        # compat=identical doesn't do any more kernel calls than compat=equals
+        assert kernel_call_count == 18
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        # When the test for different turns true halfway through,
+        # stop computing variables as it would not have any benefit
+        ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])})
+        out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # the variables of ds1 and ds2 were computed, but those of ds3 were not
+        assert kernel_call_count == 22
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+        # the data of ds1 and ds2 was loaded into numpy and then
+        # concatenated to the data of ds3. Thus, only ds3 is computed now.
+        out.compute()
+        assert kernel_call_count == 24
+
+        # Finally, test that originals are unaltered
+        assert ds1['d'].data is d1
+        assert ds1['c'].data is c1
+        assert ds2['d'].data is d2
+        assert ds2['c'].data is c2
+        assert ds3['d'].data is d3
+        assert ds3['c'].data is c3
+
     def test_groupby(self):
         if LooseVersion(dask.__version__) == LooseVersion('0.15.3'):
             pytest.xfail('upstream bug in dask: '
@@ -517,10 +588,11 @@ def test_dask_kwargs_dataset(method):
 kernel_call_count = 0
 
 
-def kernel():
+def kernel(name):
     """Dask kernel to test pickling/unpickling and __repr__.
     Must be global to make it pickleable.
     """
+    print("kernel(%s)" % name)
     global kernel_call_count
     kernel_call_count += 1
     return np.ones(1, dtype=np.int64)
@@ -530,5 +602,5 @@ def build_dask_array(name):
     global kernel_call_count
     kernel_call_count = 0
     return dask.array.Array(
-        dask={(name, 0): (kernel, )}, name=name,
+        dask={(name, 0): (kernel, name)}, name=name,
         chunks=((1,),), dtype=np.int64)
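
The assertions on ``kernel_call_count`` work because ``build_dask_array`` wraps a counting kernel in a single-chunk dask array, so every materialisation of the data bumps a global counter. A rough standalone sketch of the same trick, independent of the test module above (all names are illustrative):

```python
import dask.array
import numpy as np

call_count = 0


def counting_kernel():
    # Top-level function so dask can pickle it; each call bumps the counter.
    global call_count
    call_count += 1
    return np.ones(1, dtype=np.int64)


# A one-element dask array whose only task is the counting kernel.
arr = dask.array.Array({('a', 0): (counting_kernel,)}, name='a',
                       chunks=((1,),), dtype=np.int64)

arr.compute()
arr.compute()
print(call_count)  # 2: nothing is cached, each compute() re-runs the kernel
```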
