Skip to content

map_blocks: Allow passing dask-backed objects in args #3818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ New Features
- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases
where the result of a computation could not be inferred automatically.
By `Deepak Cherian <https://github.com/dcherian>`_
- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`)
By `Deepak Cherian <https://github.com/dcherian>`_

Bug fixes
~~~~~~~~~
Expand Down
84 changes: 65 additions & 19 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3262,45 +3262,91 @@ def map_blocks(
----------
func: callable
User-provided function that accepts a DataArray as its first
parameter. The function will receive a subset, i.e. one block, of this DataArray
(see below), corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(block_subset, *args, **kwargs)``.
parameter. The function will receive a subset or 'block' of this DataArray (see below),
corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(subset_dataarray, *subset_args, **kwargs)``.

This function must return either a single DataArray or a single Dataset.

This function cannot add a new chunked dimension.

obj: DataArray, Dataset
Passed to the function as its first argument, one block at a time.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed to func after unpacking and subsetting any xarray objects by blocks.
xarray objects in args must be aligned with obj, otherwise an error is raised.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
subset to blocks. Passing dask collections in kwargs is not allowed.
template: (optional) DataArray, Dataset
xarray object representing the final result after compute is called. If not provided,
the function will be first run on mocked-up data, that looks like 'obj' but
the function will be first run on mocked-up data, that looks like ``obj`` but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).
'template' must be provided if the function changes the size of existing dimensions.
variable names, attributes, new dimensions and new indexes (if any).
``template`` must be provided if the function changes the size of existing dimensions.
When provided, ``attrs`` on variables in `template` are copied over to the result. Any
``attrs`` set by ``func`` will be ignored.


Returns
-------
A single DataArray or Dataset with dask backend, reassembled from the outputs of
the function.
A single DataArray or Dataset with dask backend, reassembled from the outputs of the
function.

Notes
-----
This method is designed for when one needs to manipulate a whole xarray object
within each chunk. In the more common case where one can work on numpy arrays,
it is recommended to use apply_ufunc.
This function is designed for when ``func`` needs to manipulate a whole xarray object
subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
recommended to use ``apply_ufunc``.

If none of the variables in this DataArray is backed by dask, calling this
method is equivalent to calling ``func(self, *args, **kwargs)``.
If none of the variables in ``obj`` is backed by dask arrays, calling this function is
equivalent to calling ``func(obj, *args, **kwargs)``.

See Also
--------
dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
xarray.Dataset.map_blocks
dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
xarray.DataArray.map_blocks

Examples
--------

Calculate an anomaly from climatology using ``.groupby()``. Using
``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
its indices, and its methods like ``.groupby()``.

>>> def calculate_anomaly(da, groupby_type="time.month"):
... gb = da.groupby(groupby_type)
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... ).chunk()
>>> array.map_blocks(calculate_anomaly, template=array).compute()
<xarray.DataArray (time: 24)>
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
-0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 ,
0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00

Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:

>>> array.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
... )
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""
from .parallel import map_blocks

Expand Down
83 changes: 65 additions & 18 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5721,45 +5721,92 @@ def map_blocks(
----------
func: callable
User-provided function that accepts a Dataset as its first
parameter. The function will receive a subset, i.e. one block, of this Dataset
(see below), corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(block_subset, *args, **kwargs)``.
parameter. The function will receive a subset or 'block' of this Dataset (see below),
corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(subset_dataset, *subset_args, **kwargs)``.

This function must return either a single DataArray or a single Dataset.

This function cannot add a new chunked dimension.

obj: DataArray, Dataset
Passed to the function as its first argument, one block at a time.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed to func after unpacking and subsetting any xarray objects by blocks.
xarray objects in args must be aligned with obj, otherwise an error is raised.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
subset to blocks. Passing dask collections in kwargs is not allowed.
template: (optional) DataArray, Dataset
xarray object representing the final result after compute is called. If not provided,
the function will be first run on mocked-up data, that looks like 'obj' but
the function will be first run on mocked-up data, that looks like ``obj`` but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).
'template' must be provided if the function changes the size of existing dimensions.
variable names, attributes, new dimensions and new indexes (if any).
``template`` must be provided if the function changes the size of existing dimensions.
When provided, ``attrs`` on variables in `template` are copied over to the result. Any
``attrs`` set by ``func`` will be ignored.


Returns
-------
A single DataArray or Dataset with dask backend, reassembled from the outputs of
the function.
A single DataArray or Dataset with dask backend, reassembled from the outputs of the
function.

Notes
-----
This method is designed for when one needs to manipulate a whole xarray object
within each chunk. In the more common case where one can work on numpy arrays,
it is recommended to use apply_ufunc.
This function is designed for when ``func`` needs to manipulate a whole xarray object
subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
recommended to use ``apply_ufunc``.

If none of the variables in this Dataset is backed by dask, calling this method
is equivalent to calling ``func(self, *args, **kwargs)``.
If none of the variables in ``obj`` is backed by dask arrays, calling this function is
equivalent to calling ``func(obj, *args, **kwargs)``.

See Also
--------
dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
xarray.DataArray.map_blocks

Examples
--------

Calculate an anomaly from climatology using ``.groupby()``. Using
``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
its indices, and its methods like ``.groupby()``.

>>> def calculate_anomaly(da, groupby_type="time.month"):
... gb = da.groupby(groupby_type)
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... ).chunk()
>>> ds = xr.Dataset({"a": array})
>>> ds.map_blocks(calculate_anomaly, template=ds).compute()
<xarray.DataArray (time: 24)>
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
-0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 ,
0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00

Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:

>>> ds.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds,
... )
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""
from .parallel import map_blocks

Expand Down
Loading