-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
pass dask compute/persist args through from load/compute/persist #1543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7f0a3b8
df4f5d9
a186e70
490784a
a7af62a
2b506c5
a2dbe26
6fd941f
b5cc3bb
a879214
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -307,29 +307,49 @@ def data(self, data): | |
def _indexable_data(self): | ||
return orthogonally_indexable(self._data) | ||
|
||
def load(self): | ||
def load(self, **kwargs): | ||
"""Manually trigger loading of this variable's data from disk or a | ||
remote source into memory and return this variable. | ||
|
||
Normally, it should not be necessary to call this method in user code, | ||
because all xarray functions should either work on deferred data or | ||
load data automatically. | ||
|
||
Parameters | ||
---------- | ||
**kwargs : dict | ||
Additional keyword arguments passed on to ``dask.array.compute``. | ||
|
||
See Also | ||
-------- | ||
dask.array.compute | ||
""" | ||
if not isinstance(self._data, np.ndarray): | ||
if isinstance(self._data, dask_array_type): | ||
self._data = np.asarray(self._data.compute(**kwargs)) | ||
elif not isinstance(self._data, np.ndarray): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This clause should be removed as it causes inconsistent behaviour with numpy scalar types. I cannot think of any other use case where data is neither a dask ARRAY nor a numpy ndarray? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this allows for on-disk type arrays. @shoyer, any thoughts on calling […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, we need this to support on-disk arrays that aren't backed by dask. (I'd love to get rid of this in favor of always using dask, but dask has some limitations that make this tricky.) |
||
self._data = np.asarray(self._data) | ||
return self | ||
|
||
def compute(self): | ||
def compute(self, **kwargs): | ||
"""Manually trigger loading of this variable's data from disk or a | ||
remote source into memory and return a new variable. The original is | ||
left unaltered. | ||
|
||
Normally, it should not be necessary to call this method in user code, | ||
because all xarray functions should either work on deferred data or | ||
load data automatically. | ||
|
||
Parameters | ||
---------- | ||
**kwargs : dict | ||
Additional keyword arguments passed on to ``dask.array.compute``. | ||
|
||
See Also | ||
-------- | ||
dask.array.compute | ||
""" | ||
new = self.copy(deep=False) | ||
return new.load() | ||
return new.load(**kwargs) | ||
|
||
@property | ||
def values(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,11 +11,12 @@ | |
from xarray.core.pycompat import suppress | ||
from . import TestCase, requires_dask | ||
|
||
from xarray.tests import unittest | ||
from xarray.tests import unittest, assert_equal | ||
|
||
with suppress(ImportError): | ||
import dask | ||
import dask.array as da | ||
import dask.multiprocessing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No longer used |
||
|
||
|
||
class DaskTestCase(TestCase): | ||
|
@@ -182,6 +183,26 @@ def test_bivariate_ufunc(self): | |
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) | ||
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) | ||
|
||
def test_compute_args(self): | ||
a = DataArray([1, 2]).chunk() | ||
expected = DataArray([1, 4]) | ||
b = a * a | ||
# compute | ||
b1 = b.compute(get=dask.multiprocessing.get) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The clean way to test this is probably with mock, e.g.,
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @shoyer, I'll add mock as part of this PR. |
||
assert b1._in_memory | ||
assert_equal(b1, expected) | ||
b2 = b.compute(get=dask.multiprocessing.get, num_workers=4) | ||
assert b2._in_memory | ||
assert_equal(b2, expected) | ||
# load | ||
b3 = b.load(get=dask.multiprocessing.get, num_workers=4) | ||
assert b3._in_memory | ||
assert_equal(b3, expected) | ||
# persist | ||
b4 = b.persist(get=dask.multiprocessing.get, num_workers=4) | ||
assert b4._in_memory | ||
assert_equal(b4, expected) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant with the test below? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed, these should have been |
||
|
||
@requires_dask | ||
class TestDataArrayAndDataset(DaskTestCase): | ||
|
@@ -393,6 +414,32 @@ def test_from_dask_variable(self): | |
coords={'x': range(4)}, name='foo') | ||
self.assertLazyAndIdentical(self.lazy_array, a) | ||
|
||
def test_compute_args(self): | ||
a = DataArray([1, 2], name='a').chunk() | ||
expected = DataArray([1, 4], name='expected') | ||
b = a * a | ||
# compute | ||
b1 = b.compute(get=dask.multiprocessing.get) | ||
assert b1._in_memory | ||
assert_equal(b1, expected) | ||
b2 = b.compute(get=dask.multiprocessing.get, num_workers=4) | ||
assert b2._in_memory | ||
assert_equal(b2, expected) | ||
# load | ||
b3 = b.load(get=dask.multiprocessing.get, num_workers=4) | ||
assert b3._in_memory | ||
assert_equal(b3, expected) | ||
# persist | ||
b4 = b.persist(get=dask.multiprocessing.get, num_workers=4) | ||
assert b4._in_memory | ||
assert_equal(b4, expected) | ||
|
||
# dataset | ||
ds = a.to_dataset() | ||
ds.compute(get=dask.multiprocessing.get, num_workers=4) | ||
ds.load(get=dask.multiprocessing.get, num_workers=4) | ||
ds.persist(get=dask.multiprocessing.get, num_workers=4) | ||
|
||
|
||
kernel_call_count = 0 | ||
def kernel(): | ||
|
@@ -403,6 +450,7 @@ def kernel(): | |
kernel_call_count += 1 | ||
return np.ones(1) | ||
|
||
|
||
def build_dask_array(): | ||
global kernel_call_count | ||
kernel_call_count = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing tests for Variable |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't want to invoke asarray if dask returns a scalar numpy type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be honest, we don't define what can go in `_data` as carefully as we ought to. I guess there are two ways to define it: anything that exposes `shape`, `dtype` and `__getitem__`, or anything returned by `xarray.core.variable.as_compatible_data`.
Numpy scalars do actually pass through here (since they define all those attributes!)... but then would get converted into an array when calling `.values` anyways (see xarray/xarray/core/variable.py, lines 334 to 337 in 78ca20a).
So I guess I agree, but on the other hand I'm also a little nervous that a dask routine might return a non-numpy scalar, which would definitely break if we don't wrap it in `asarray`. The safe thing to do is to leave this as is or call `as_compatible_data` on it.