
pass dask compute/persist args through from load/compute/persist #1543


Merged · 10 commits · Sep 5, 2017
39 changes: 33 additions & 6 deletions xarray/core/dataarray.py
@@ -565,22 +565,31 @@ def reset_coords(self, names=None, drop=False, inplace=False):
         dataset[self.name] = self.variable
         return dataset

-    def load(self):
+    def load(self, **kwargs):
         """Manually trigger loading of this array's data from disk or a
         remote source into memory and return this array.

         Normally, it should not be necessary to call this method in user code,
         because all xarray functions should either work on deferred data or
         load data automatically. However, this method can be necessary when
         working with many file objects on disk.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
-        ds = self._to_temp_dataset().load()
+        ds = self._to_temp_dataset().load(**kwargs)
         new = self._from_temp_dataset(ds)
         self._variable = new._variable
         self._coords = new._coords
         return self

-    def compute(self):
+    def compute(self, **kwargs):
         """Manually trigger loading of this array's data from disk or a
         remote source into memory and return a new array. The original is
         left unaltered.

@@ -589,18 +598,36 @@ def compute(self):
         because all xarray functions should either work on deferred data or
         load data automatically. However, this method can be necessary when
         working with many file objects on disk.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
         new = self.copy(deep=False)
-        return new.load()
+        return new.load(**kwargs)

-    def persist(self):
+    def persist(self, **kwargs):
         """ Trigger computation in constituent dask arrays

         This keeps them as dask arrays but encourages them to keep data in
         memory. This is particularly useful when on a distributed machine.
         When on a single machine consider using ``.compute()`` instead.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.persist``.
+
+        See Also
+        --------
+        dask.persist
         """
-        ds = self._to_temp_dataset().persist()
+        ds = self._to_temp_dataset().persist(**kwargs)
         return self._from_temp_dataset(ds)

     def copy(self, deep=True):
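
With this change, scheduler options flow straight through xarray's methods. A minimal usage sketch, assuming the ``get=`` scheduler interface that dask exposed at the time of this PR (later dask versions replaced ``get=`` with ``scheduler=``):

    import dask.multiprocessing
    import xarray as xr

    arr = xr.DataArray([1, 2, 3]).chunk()
    doubled = arr * 2

    # kwargs are forwarded to dask.array.compute
    result = doubled.compute(get=dask.multiprocessing.get, num_workers=4)

    # persist materializes the chunks but keeps them as dask arrays
    persisted = doubled.persist(get=dask.multiprocessing.get)
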
43 changes: 35 additions & 8 deletions xarray/core/dataset.py
@@ -445,14 +445,23 @@ def sizes(self):
         """
         return self.dims

-    def load(self):
+    def load(self, **kwargs):
         """Manually trigger loading of this dataset's data from disk or a
         remote source into memory and return this dataset.

         Normally, it should not be necessary to call this method in user code,
         because all xarray functions should either work on deferred data or
         load data automatically. However, this method can be necessary when
         working with many file objects on disk.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {k: v._data for k, v in self.variables.items()

@@ -461,7 +470,7 @@ def load(self):
             import dask.array as da

             # evaluate all the dask arrays simultaneously
-            evaluated_data = da.compute(*lazy_data.values())
+            evaluated_data = da.compute(*lazy_data.values(), **kwargs)

             for k, data in zip(lazy_data, evaluated_data):
                 self.variables[k].data = data

@@ -473,7 +482,7 @@

         return self

-    def compute(self):
+    def compute(self, **kwargs):
         """Manually trigger loading of this dataset's data from disk or a
         remote source into memory and return a new dataset. The original is
         left unaltered.

@@ -482,11 +491,20 @@ def compute(self):
         because all xarray functions should either work on deferred data or
         load data automatically. However, this method can be necessary when
         working with many file objects on disk.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
         new = self.copy(deep=False)
-        return new.load()
+        return new.load(**kwargs)

-    def _persist_inplace(self):
+    def _persist_inplace(self, **kwargs):
         """ Persist all Dask arrays in memory """
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {k: v._data for k, v in self.variables.items()

@@ -495,24 +513,33 @@
             import dask

             # evaluate all the dask arrays simultaneously
-            evaluated_data = dask.persist(*lazy_data.values())
+            evaluated_data = dask.persist(*lazy_data.values(), **kwargs)

             for k, data in zip(lazy_data, evaluated_data):
                 self.variables[k].data = data

         return self

-    def persist(self):
+    def persist(self, **kwargs):
         """ Trigger computation, keeping data as dask arrays

         This operation can be used to trigger computation on underlying dask
         arrays, similar to ``.compute()``. However this operation keeps the
         data as dask arrays. This is particularly useful when using the
         dask.distributed scheduler and you want to load a large amount of data
         into distributed memory.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.persist``.
+
+        See Also
+        --------
+        dask.persist
         """
         new = self.copy(deep=False)
-        return new._persist_inplace()
+        return new._persist_inplace(**kwargs)

     @classmethod
     def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
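
Note why ``Dataset.load`` collects every lazy variable and makes a single ``da.compute`` call rather than computing variables one by one: dask merges the graphs, so shared inputs are evaluated once, and the forwarded ``**kwargs`` apply to the whole batch. A standalone sketch of the same pattern (illustrative, independent of xarray):

    import numpy as np
    import dask.array as da

    x = da.from_array(np.arange(10), chunks=5)
    lazy = {'y': x + 1, 'z': x * 2}  # two arrays sharing the chunks of x

    # one compute call: dask merges both graphs, reads x only once,
    # and kwargs such as num_workers apply to the merged graph
    y, z = da.compute(*lazy.values(), num_workers=2)
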
28 changes: 24 additions & 4 deletions xarray/core/variable.py
@@ -307,29 +307,49 @@ def data(self, data):
     def _indexable_data(self):
         return orthogonally_indexable(self._data)

-    def load(self):
+    def load(self, **kwargs):
         """Manually trigger loading of this variable's data from disk or a
         remote source into memory and return this variable.

         Normally, it should not be necessary to call this method in user code,
         because all xarray functions should either work on deferred data or
         load data automatically.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
-        if not isinstance(self._data, np.ndarray):
+        if isinstance(self._data, dask_array_type):
+            self._data = np.asarray(self._data.compute(**kwargs))

Comment (Contributor): You don't want to invoke asarray if dask returns a scalar numpy type.

Comment (Member):
To be honest, we don't define what can go in _data as carefully as we ought to. I guess there are two ways to define it:

  • anything "array like" that defines at least shape, dtype and __getitem__
  • whatever comes out of xarray.core.variable.as_compatible_data

Numpy scalars do actually pass through here (since they define all of those attributes!), but they would then get converted into an array when calling .values anyway:

    @property
    def values(self):
        """The variable's data as a numpy.ndarray"""
        return _as_array_or_item(self._data)

So I guess I agree, but on the other hand I'm a little nervous that a dask routine might return a non-numpy scalar, which would definitely break if we don't wrap it in asarray. The safe thing to do is to leave this as is or to call as_compatible_data on it.
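
For context on the scalar concern above: ``np.asarray`` promotes scalars, numpy or otherwise, to 0-d arrays, which is what keeps ``_data`` array-like when a compute returns a scalar. A quick illustration:

    import numpy as np

    np.asarray(np.float64(3.0)).shape        # () -- a numpy scalar becomes a 0-d array
    np.asarray(3.0).shape                    # () -- a plain Python float does too
    isinstance(np.float64(3.0), np.ndarray)  # False: scalars are not ndarrays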

+        elif not isinstance(self._data, np.ndarray):

Comment (Contributor): This clause should be removed, as it causes inconsistent behaviour with numpy scalar types. I cannot think of any other use case where the data is neither a dask array nor a numpy ndarray.

Comment (Member Author): I think this allows for on-disk type arrays. @shoyer, any thoughts on calling np.asarray here and in the line above?

Comment (Member): Yes, we need this to support on-disk arrays that aren't backed by dask. (I'd love to get rid of this in favor of always using dask, but dask has some limitations that make this tricky.)
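
To make the on-disk case concrete, here is a hypothetical lazy wrapper (illustrative only, not xarray's actual backend class) that is neither a dask array nor an ndarray, which is exactly what the ``elif`` branch exists to handle:

    import numpy as np

    class LazyOnDiskArray:
        """Hypothetical stand-in for a non-dask, on-disk backend array."""

        def __init__(self, shape, dtype):
            self.shape = shape
            self.dtype = dtype

        def __getitem__(self, key):
            return self._read()[key]

        def __array__(self, dtype=None):
            # np.asarray(...) lands here and pulls the data into memory
            return self._read().astype(dtype or self.dtype)

        def _read(self):
            # a real backend would read from disk instead
            return np.zeros(self.shape, self.dtype)

    arr = LazyOnDiskArray((3,), np.dtype('float64'))
    np.asarray(arr)  # what the ``elif`` branch in Variable.load does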

             self._data = np.asarray(self._data)
         return self

-    def compute(self):
+    def compute(self, **kwargs):
         """Manually trigger loading of this variable's data from disk or a
         remote source into memory and return a new variable. The original is
         left unaltered.

         Normally, it should not be necessary to call this method in user code,
         because all xarray functions should either work on deferred data or
         load data automatically.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.array.compute``.
+
+        See Also
+        --------
+        dask.array.compute
         """
         new = self.copy(deep=False)
-        return new.load()
+        return new.load(**kwargs)

     @property
     def values(self):

50 changes: 49 additions & 1 deletion xarray/tests/test_dask.py
@@ -11,11 +11,12 @@
 from xarray.core.pycompat import suppress
 from . import TestCase, requires_dask

-from xarray.tests import unittest
+from xarray.tests import unittest, assert_equal

 with suppress(ImportError):
     import dask
     import dask.array as da
+    import dask.multiprocessing

Comment (Member): No longer used.

 class DaskTestCase(TestCase):

@@ -182,6 +183,26 @@ def test_bivariate_ufunc(self):
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

+    def test_compute_args(self):
+        a = DataArray([1, 2]).chunk()
+        expected = DataArray([1, 4])
+        b = a * a
+        # compute
+        b1 = b.compute(get=dask.multiprocessing.get)

Comment (Member):
The clean way to test this is probably with mock, e.g.:

    In [63]: import numpy as np

    In [64]: import dask.array as da

    In [65]: from unittest import mock

    In [66]: x = da.from_array(np.arange(3), chunks=(2,))

    In [67]: with mock.patch.object(da.Array, 'compute', return_value=np.arange(3)) as mock_compute:
        ...:     x.compute(foo='bar')
        ...:

    In [68]: mock_compute.assert_called_with(foo='bar')

    In [69]: mock_compute.assert_called_with(bar='foo')
    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    <ipython-input-69-22937cf26ca1> in <module>()
    ----> 1 mock_compute.assert_called_with(bar='foo')

    ~/conda/envs/xarray-dev/lib/python3.5/unittest/mock.py in assert_called_with(_mock_self, *args, **kwargs)
        792         if expected != actual:
        793             cause = expected if isinstance(expected, Exception) else None
    --> 794             raise AssertionError(_error_message()) from cause
        795
        796

    AssertionError: Expected call: compute(bar='foo')
    Actual call: compute(foo='bar')

unittest.mock is part of Python 3's standard library, but there's also a widely used Python 2 backport on PyPI. I think it would be perfectly fine to add it as a dependency for our test suite.

Comment (Member Author): Thanks @shoyer, I'll add mock as part of this PR.
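
A sketch of what such a mock-based test could look like against this PR's implementation, where the kwargs end up in the ``da.compute`` call inside ``Dataset.load``; the test name and patch target are illustrative, not the code that eventually landed:

    import numpy as np
    from unittest import mock  # on Python 2, the ``mock`` backport from PyPI

    import xarray as xr

    def test_compute_forwards_kwargs():
        arr = xr.DataArray([1, 2]).chunk()
        with mock.patch('dask.array.compute',
                        return_value=(np.array([1, 2]),)) as compute:
            arr.compute(foo='bar')
        # the one lazy variable is passed positionally; kwargs come through
        compute.assert_called_with(mock.ANY, foo='bar')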

+        assert b1._in_memory
+        assert_equal(b1, expected)
+        b2 = b.compute(get=dask.multiprocessing.get, num_workers=4)
+        assert b2._in_memory
+        assert_equal(b2, expected)
+        # load
+        b3 = b.load(get=dask.multiprocessing.get, num_workers=4)
+        assert b3._in_memory
+        assert_equal(b3, expected)
+        # persist
+        b4 = b.persist(get=dask.multiprocessing.get, num_workers=4)
+        assert b4._in_memory
+        assert_equal(b4, expected)


Comment (Contributor): Redundant with the test below?

Comment (Member Author): Fixed: these should have been Variables.
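
Presumably the fix rewrites this first test in terms of ``Variable`` objects, along these lines (a sketch under that assumption, not the committed code; ``Variable`` gains no ``persist`` method in this PR, so only ``compute`` and ``load`` are exercised):

    def test_compute_args(self):
        a = Variable('x', [1, 2]).chunk()
        expected = Variable('x', [1, 4])
        b = a * a
        b1 = b.compute(get=dask.multiprocessing.get, num_workers=4)
        assert b1._in_memory
        assert_equal(b1, expected)
        b2 = b.load(get=dask.multiprocessing.get, num_workers=4)
        assert b2._in_memory
        assert_equal(b2, expected)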


 @requires_dask
 class TestDataArrayAndDataset(DaskTestCase):

@@ -393,6 +414,32 @@ def test_from_dask_variable(self):
                            coords={'x': range(4)}, name='foo')
         self.assertLazyAndIdentical(self.lazy_array, a)

+    def test_compute_args(self):
+        a = DataArray([1, 2], name='a').chunk()
+        expected = DataArray([1, 4], name='expected')
+        b = a * a
+        # compute
+        b1 = b.compute(get=dask.multiprocessing.get)
+        assert b1._in_memory
+        assert_equal(b1, expected)
+        b2 = b.compute(get=dask.multiprocessing.get, num_workers=4)
+        assert b2._in_memory
+        assert_equal(b2, expected)
+        # load
+        b3 = b.load(get=dask.multiprocessing.get, num_workers=4)
+        assert b3._in_memory
+        assert_equal(b3, expected)
+        # persist
+        b4 = b.persist(get=dask.multiprocessing.get, num_workers=4)
+        assert b4._in_memory
+        assert_equal(b4, expected)
+
+        # dataset
+        ds = a.to_dataset()
+        ds.compute(get=dask.multiprocessing.get, num_workers=4)
+        ds.load(get=dask.multiprocessing.get, num_workers=4)
+        ds.persist(get=dask.multiprocessing.get, num_workers=4)

 kernel_call_count = 0
 def kernel():

@@ -403,6 +450,7 @@ def kernel():
     kernel_call_count += 1
     return np.ones(1)
+

 def build_dask_array():
     global kernel_call_count
     kernel_call_count = 0

Comment (Contributor): Missing tests for Variable.