Skip to content

pass dask compute/persist args through from load/compute/persist #1543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 5, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements-py27-cdat+pynio.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- pathlib2
- pynio
- pytest
- mock
- scipy
- seaborn
- toolz
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py27-min.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: test_env
dependencies:
- python=2.7
- pytest
- mock
- numpy==1.11
- pandas==0.18.0
- pip:
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py27-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- netcdf4
- pathlib2
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py34.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ dependencies:
- python=3.4
- bottleneck
- pytest
- mock
- pandas
- pip:
- coveralls
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py35.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-bottleneck-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-condaforge-rc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- seaborn
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-dask-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- seaborn
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-netcdf4-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- h5netcdf
- matplotlib
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-pandas-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- scipy
- toolz
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-py36.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- matplotlib
- netcdf4
- pytest
- mock
- numpy
- pandas
- scipy
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ Enhancements
methods. Any keyword arguments supplied to these methods are passed on to
the corresponding dask function (:issue:`1523`).
By `Joe Hamman <https://github.com/jhamman>`_.
- Encoding attributes are now preserved when xarray objects are concatenated.
The encoding is copied from the first object (:issue:`1297`).
By `Joe Hamman <https://github.com/jhamman>`_ and
`Gerrit Holl <https://github.com/gerritholl>`_.

Bug fixes
~~~~~~~~~
Expand Down
95 changes: 54 additions & 41 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pickle
import numpy as np
import pandas as pd
import pytest
import mock
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Python 3, use from unittest import mock to avoid adding the new dependency.


import xarray as xr
from xarray import Variable, DataArray, Dataset
Expand Down Expand Up @@ -183,22 +185,6 @@ def test_bivariate_ufunc(self):
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

def test_compute_args(self):
    """Check that ``compute`` and ``load`` accept scheduler keyword arguments.

    A chunked variable is squared lazily and then evaluated with the
    ``get`` / ``num_workers`` keywords; each result must be in memory and
    equal to the expected values.
    """
    base = Variable('x', [1, 2]).chunk()
    expected = Variable('x', [1, 4])
    squared = base * base
    scheduler = dask.multiprocessing.get

    # compute: first with the scheduler alone, then with a worker count too
    for options in ({'get': scheduler},
                    {'get': scheduler, 'num_workers': 4}):
        computed = squared.compute(**options)
        assert computed._in_memory
        assert_equal(computed, expected)

    # load
    loaded = squared.load(get=scheduler, num_workers=4)
    assert loaded._in_memory
    assert_equal(loaded, expected)


@requires_dask
class TestDataArrayAndDataset(DaskTestCase):
Expand Down Expand Up @@ -410,31 +396,58 @@ def test_from_dask_variable(self):
coords={'x': range(4)}, name='foo')
self.assertLazyAndIdentical(self.lazy_array, a)

def test_compute_args(self):
    """Exercise ``compute``, ``load`` and ``persist`` with scheduler kwargs.

    The DataArray results are checked for being in memory and equal to the
    expected values; the Dataset calls are only checked for accepting the
    keyword arguments.
    """
    source = DataArray([1, 2], name='a').chunk()
    expected = DataArray([1, 4], name='expected')
    product = source * source
    scheduler = dask.multiprocessing.get

    # compute
    only_get = product.compute(get=scheduler)
    assert only_get._in_memory
    assert_equal(only_get, expected)
    with_workers = product.compute(get=scheduler, num_workers=4)
    assert with_workers._in_memory
    assert_equal(with_workers, expected)

    # load, then persist (same order as before), both with a worker count
    for name in ('load', 'persist'):
        outcome = getattr(product, name)(get=scheduler, num_workers=4)
        assert outcome._in_memory
        assert_equal(outcome, expected)

    # dataset
    dataset = source.to_dataset()
    for name in ('compute', 'load', 'persist'):
        getattr(dataset, name)(get=scheduler, num_workers=4)

@pytest.mark.parametrize("method", ['load', 'compute'])
def test_dask_kwargs_variable(method):
x = Variable('y', da.from_array(np.arange(3), chunks=(2,)))
with mock.patch.object(Variable, method,
return_value=np.arange(3)) as mock_method:
getattr(x, method)(foo='bar')
mock_method.assert_called_with(foo='bar')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can skip this first test. This verifies that Variable.compute() was called if you directly call Variable.compute().


# args should be passed on to da.Array.compute()
with mock.patch.object(da.Array, 'compute',
return_value=np.arange(3)) as mock_compute:
getattr(x, method)(foo='bar')
mock_compute.assert_called_with(foo='bar')


@pytest.mark.parametrize("method", ['load', 'compute', 'persist'])
def test_dask_kwargs_dataarray(method):
    """Check that keyword arguments given to ``DataArray.<method>`` reach dask.

    ``load`` and ``compute`` are checked against ``dask.array.compute``;
    ``persist`` is checked against ``dask.persist``.

    Parameters
    ----------
    method : str
        Name of the DataArray method under test.
    """
    data = da.from_array(np.arange(3), chunks=(2,))
    x = DataArray(data)

    # Only the dask entry point is patched. Patching ``DataArray.<method>``
    # itself and asserting on that mock would merely confirm the call made
    # by this test, not that the arguments are forwarded.
    if method in ['load', 'compute']:
        dask_func = 'dask.array.compute'
    else:
        dask_func = 'dask.persist'
    # args should be passed on to "dask_func"
    with mock.patch(dask_func) as mock_func:
        getattr(x, method)(foo='bar')
        mock_func.assert_called_with(data, foo='bar')


@pytest.mark.parametrize("method", ['load', 'compute', 'persist'])
def test_dask_kwargs_dataset(method):
    """Check that keyword arguments given to ``Dataset.<method>`` reach dask.

    ``load`` and ``compute`` are checked against ``dask.array.compute``;
    ``persist`` is checked against ``dask.persist``.

    Parameters
    ----------
    method : str
        Name of the Dataset method under test.
    """
    data = da.from_array(np.arange(3), chunks=(2,))
    x = Dataset({'x': (('y'), data)})

    # Only the dask entry point is patched. Patching ``Dataset.<method>``
    # itself and asserting on that mock would merely confirm the call made
    # by this test, not that the arguments are forwarded.
    if method in ['load', 'compute']:
        dask_func = 'dask.array.compute'
    else:
        dask_func = 'dask.persist'
    # args should be passed on to "dask_func"
    with mock.patch(dask_func) as mock_func:
        getattr(x, method)(foo='bar')
        mock_func.assert_called_with(data, foo='bar')


kernel_call_count = 0
Expand Down