Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Alternative fix for #188 #268

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
HybridMappingProxy,
_default,
either_dict_or_kwargs,
maybe_wrap_array,
)
from xarray.core.variable import Variable

Expand Down Expand Up @@ -235,6 +236,65 @@ def _replace(
inplace=inplace,
)

def map(
    self,
    func: Callable,
    keep_attrs: bool | None = None,
    args: Iterable[Any] = (),
    **kwargs: Any,
) -> Dataset:
    """Apply a function to each data variable in this dataset

    Parameters
    ----------
    func : callable
        Function which can be called in the form `func(x, *args, **kwargs)`
        to transform each DataArray `x` in this dataset into another
        DataArray.
    keep_attrs : bool or None, optional
        If True, both the dataset's and variables' attributes (`attrs`) will be
        copied from the original objects to the new ones. If False, the new dataset
        and variables will be returned without copying the attributes.
        If None (the default), xarray's global ``keep_attrs`` option is
        consulted, falling back to False when that option is unset.
    args : iterable, optional
        Positional arguments passed on to `func`.
    **kwargs : Any
        Keyword arguments passed on to `func`.

    Returns
    -------
    applied : Dataset
        Resulting dataset from applying ``func`` to each data variable.
        Always a plain ``Dataset``, regardless of the type of ``self``.

    Examples
    --------
    >>> da = xr.DataArray(np.random.randn(2, 3))
    >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
    >>> ds
    <xarray.Dataset>
    Dimensions:  (dim_0: 2, dim_1: 3, x: 2)
    Dimensions without coordinates: dim_0, dim_1, x
    Data variables:
        foo      (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773
        bar      (x) int64 -1 2
    >>> ds.map(np.fabs)
    <xarray.Dataset>
    Dimensions:  (dim_0: 2, dim_1: 3, x: 2)
    Dimensions without coordinates: dim_0, dim_1, x
    Data variables:
        foo      (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
        bar      (x) float64 1.0 2.0
    """
    # Imported locally so this override does not depend on the module's
    # top-level import list; mirrors how xarray.Dataset.map resolves the flag.
    from xarray.core.options import _get_keep_attrs

    if keep_attrs is None:
        keep_attrs = _get_keep_attrs(default=False)

    # Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188).
    # TODO Refactor xarray upstream to avoid needing to overwrite this.
    variables = {
        k: maybe_wrap_array(v, func(v, *args, **kwargs))
        for k, v in self.data_vars.items()
    }

    if keep_attrs:
        # Restore each variable's attrs, which `func` may have discarded.
        for k, v in variables.items():
            v.attrs = self.data_vars[k].attrs

    attrs = self.attrs if keep_attrs else None

    # Build a plain Dataset rather than type(self)(...); see the note above.
    return Dataset(variables, attrs=attrs)


class DataTree(
NamedNode,
Expand Down