diff --git a/ci/requirements-py27-cdat+pynio.yml b/ci/requirements-py27-cdat+pynio.yml index 0258c8c9672..ccd3fbf9cb4 100644 --- a/ci/requirements-py27-cdat+pynio.yml +++ b/ci/requirements-py27-cdat+pynio.yml @@ -16,6 +16,7 @@ dependencies: - pathlib2 - pynio - pytest + - mock - scipy - seaborn - toolz diff --git a/ci/requirements-py27-min.yml b/ci/requirements-py27-min.yml index 9c7d7c5a9e9..6f63315db67 100644 --- a/ci/requirements-py27-min.yml +++ b/ci/requirements-py27-min.yml @@ -2,6 +2,7 @@ name: test_env dependencies: - python=2.7 - pytest + - mock - numpy==1.11 - pandas==0.18.0 - pip: diff --git a/ci/requirements-py27-windows.yml b/ci/requirements-py27-windows.yml index e953b5ffdcb..73baca68dfa 100644 --- a/ci/requirements-py27-windows.yml +++ b/ci/requirements-py27-windows.yml @@ -11,6 +11,7 @@ dependencies: - netcdf4 - pathlib2 - pytest + - mock - numpy - pandas - scipy diff --git a/doc/installing.rst b/doc/installing.rst index a316ef38fc5..522577a078b 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -73,6 +73,12 @@ pandas) installed first. Then, install xarray with pip:: $ pip install xarray -To run the test suite after installing xarray, install -`py.test <http://pytest.org>`__ (``pip install pytest``) and run +Testing +------- + +To run the test suite after installing xarray, first install (via pypi or conda) +- `py.test <http://pytest.org>`__: Simple unit testing library +- `mock <https://pypi.python.org/pypi/mock>`__: additional testing library required for python version 2 + +and run ``py.test --pyargs xarray``. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8f17601f650..6f0466acf92 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -97,6 +97,10 @@ Enhancements other means (:issue:`1459`). By `Ryan May <https://github.com/dopplershift>`_. + - Support passing keyword arguments to ``load``, ``compute``, and ``persist`` + methods. Any keyword arguments supplied to these methods are passed on to + the corresponding dask function (:issue:`1523`). + By `Joe Hamman <https://github.com/jhamman>`_. 
- Encoding attributes are now preserved when xarray objects are concatenated. The encoding is copied from the first object (:issue:`1297`). By `Joe Hamman `_ and diff --git a/setup.py b/setup.py index 6ff8de60666..e157a825d07 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,8 @@ INSTALL_REQUIRES = ['numpy >= 1.11', 'pandas >= 0.18.0'] TESTS_REQUIRE = ['pytest >= 2.7.1'] +if sys.version_info[0] < 3: + TESTS_REQUIRE.append('mock') DESCRIPTION = "N-D labeled arrays and datasets in Python" LONG_DESCRIPTION = """ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 14e53aababf..ea7e46d8225 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -565,7 +565,7 @@ def reset_coords(self, names=None, drop=False, inplace=False): dataset[self.name] = self.variable return dataset - def load(self): + def load(self, **kwargs): """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -573,14 +573,23 @@ def load(self): because all xarray functions should either work on deferred data or load data automatically. However, this method can be necessary when working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute """ - ds = self._to_temp_dataset().load() + ds = self._to_temp_dataset().load(**kwargs) new = self._from_temp_dataset(ds) self._variable = new._variable self._coords = new._coords return self - def compute(self): + def compute(self, **kwargs): """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -589,18 +598,36 @@ def compute(self): because all xarray functions should either work on deferred data or load data automatically. However, this method can be necessary when working with many file objects on disk. 
+ + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute """ new = self.copy(deep=False) - return new.load() + return new.load(**kwargs) - def persist(self): + def persist(self, **kwargs): """ Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in memory. This is particularly useful when on a distributed machine. When on a single machine consider using ``.compute()`` instead. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.persist``. + + See Also + -------- + dask.persist """ - ds = self._to_temp_dataset().persist() + ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) def copy(self, deep=True): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index aa49d8a73b0..37021f865d9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -445,7 +445,7 @@ def sizes(self): """ return self.dims - def load(self): + def load(self, **kwargs): """Manually trigger loading of this dataset's data from disk or a remote source into memory and return this dataset. @@ -453,6 +453,15 @@ def load(self): because all xarray functions should either work on deferred data or load data automatically. However, this method can be necessary when working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. 
+ + See Also + -------- + dask.array.compute """ # access .data to coerce everything to numpy or dask arrays lazy_data = {k: v._data for k, v in self.variables.items() @@ -461,7 +470,7 @@ def load(self): import dask.array as da # evaluate all the dask arrays simultaneously - evaluated_data = da.compute(*lazy_data.values()) + evaluated_data = da.compute(*lazy_data.values(), **kwargs) for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data @@ -473,7 +482,7 @@ def load(self): return self - def compute(self): + def compute(self, **kwargs): """Manually trigger loading of this dataset's data from disk or a remote source into memory and return a new dataset. The original is left unaltered. @@ -482,11 +491,20 @@ def compute(self): because all xarray functions should either work on deferred data or load data automatically. However, this method can be necessary when working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute """ new = self.copy(deep=False) - return new.load() + return new.load(**kwargs) - def _persist_inplace(self): + def _persist_inplace(self, **kwargs): """ Persist all Dask arrays in memory """ # access .data to coerce everything to numpy or dask arrays lazy_data = {k: v._data for k, v in self.variables.items() @@ -495,14 +513,14 @@ def _persist_inplace(self): import dask # evaluate all the dask arrays simultaneously - evaluated_data = dask.persist(*lazy_data.values()) + evaluated_data = dask.persist(*lazy_data.values(), **kwargs) for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data return self - def persist(self): + def persist(self, **kwargs): """ Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask @@ -510,9 +528,18 @@ def persist(self): data as dask arrays. 
This is particularly useful when using the dask.distributed scheduler and you want to load a large amount of data into distributed memory. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.persist``. + + See Also + -------- + dask.persist """ new = self.copy(deep=False) - return new._persist_inplace() + return new._persist_inplace(**kwargs) @classmethod def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2da9e599e1b..dc8f3b39d2d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -307,19 +307,30 @@ def data(self, data): def _indexable_data(self): return orthogonally_indexable(self._data) - def load(self): + def load(self, **kwargs): """Manually trigger loading of this variable's data from disk or a remote source into memory and return this variable. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or load data automatically. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute """ - if not isinstance(self._data, np.ndarray): + if isinstance(self._data, dask_array_type): + self._data = as_compatible_data(self._data.compute(**kwargs)) + elif not isinstance(self._data, np.ndarray): self._data = np.asarray(self._data) return self - def compute(self): + def compute(self, **kwargs): """Manually trigger loading of this variable's data from disk or a remote source into memory and return a new variable. The original is left unaltered. @@ -327,9 +338,18 @@ def compute(self): Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or load data automatically. 
+ + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute """ new = self.copy(deep=False) - return new.load() + return new.load(**kwargs) @property def values(self): diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7afad6ffe92..05c4cd340cb 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -20,6 +20,11 @@ except ImportError: import unittest +try: + from unittest import mock +except ImportError: + import mock + try: import scipy has_scipy = True diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 92f616b8bd6..422c34adfa3 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -4,6 +4,7 @@ import pickle import numpy as np import pandas as pd +import pytest import xarray as xr from xarray import Variable, DataArray, Dataset @@ -11,7 +12,7 @@ from xarray.core.pycompat import suppress from . import TestCase, requires_dask -from xarray.tests import unittest +from xarray.tests import unittest, mock with suppress(ImportError): import dask @@ -394,6 +395,47 @@ def test_from_dask_variable(self): self.assertLazyAndIdentical(self.lazy_array, a) +@requires_dask +@pytest.mark.parametrize("method", ['load', 'compute']) +def test_dask_kwargs_variable(method): + x = Variable('y', da.from_array(np.arange(3), chunks=(2,))) + # args should be passed on to da.Array.compute() + with mock.patch.object(da.Array, 'compute', + return_value=np.arange(3)) as mock_compute: + getattr(x, method)(foo='bar') + mock_compute.assert_called_with(foo='bar') + + +@requires_dask +@pytest.mark.parametrize("method", ['load', 'compute', 'persist']) +def test_dask_kwargs_dataarray(method): + data = da.from_array(np.arange(3), chunks=(2,)) + x = DataArray(data) + if method in ['load', 'compute']: + dask_func = 'dask.array.compute' + else: + dask_func = 'dask.persist' + # args should be passed on to "dask_func" + with 
mock.patch(dask_func) as mock_func: + getattr(x, method)(foo='bar') + mock_func.assert_called_with(data, foo='bar') + + +@requires_dask +@pytest.mark.parametrize("method", ['load', 'compute', 'persist']) +def test_dask_kwargs_dataset(method): + data = da.from_array(np.arange(3), chunks=(2,)) + x = Dataset({'x': (('y'), data)}) + if method in ['load', 'compute']: + dask_func = 'dask.array.compute' + else: + dask_func = 'dask.persist' + # args should be passed on to "dask_func" + with mock.patch(dask_func) as mock_func: + getattr(x, method)(foo='bar') + mock_func.assert_called_with(data, foo='bar') + + kernel_call_count = 0 def kernel(): """Dask kernel to test pickling/unpickling. @@ -403,6 +445,7 @@ def kernel(): kernel_call_count += 1 return np.ones(1) + def build_dask_array(): global kernel_call_count kernel_call_count = 0