diff --git a/doc/api.rst b/doc/api.rst
index 4492d882355..43a9cf53ead 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -165,6 +165,7 @@ Computation
Dataset.groupby_bins
Dataset.rolling
Dataset.rolling_exp
+ Dataset.weighted
Dataset.coarsen
Dataset.resample
Dataset.diff
@@ -340,6 +341,7 @@ Computation
DataArray.groupby_bins
DataArray.rolling
DataArray.rolling_exp
+ DataArray.weighted
DataArray.coarsen
DataArray.dt
DataArray.resample
@@ -577,6 +579,22 @@ Rolling objects
core.rolling.DatasetRolling.reduce
core.rolling_exp.RollingExp

+Weighted objects
+================
+
+.. autosummary::
+ :toctree: generated/
+
+ core.weighted.DataArrayWeighted
+ core.weighted.DataArrayWeighted.mean
+ core.weighted.DataArrayWeighted.sum
+ core.weighted.DataArrayWeighted.sum_of_weights
+ core.weighted.DatasetWeighted
+ core.weighted.DatasetWeighted.mean
+ core.weighted.DatasetWeighted.sum
+ core.weighted.DatasetWeighted.sum_of_weights
+
+
Coarsen objects
===============
diff --git a/doc/computation.rst b/doc/computation.rst
index 1ac30f55ee7..5309f27e9b6 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -1,3 +1,5 @@
+.. currentmodule:: xarray
+
.. _comput:

###########
@@ -241,12 +243,94 @@ You can also use ``construct`` to compute a weighted rolling sum:
To avoid this, use ``skipna=False`` as the above example.
+.. _comput.weighted:
+
+Weighted array reductions
+=========================
+
+:py:class:`DataArray` and :py:class:`Dataset` objects include the
+:py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods for
+weighted array reductions. They currently support weighted ``sum`` and
+weighted ``mean``.
+
+.. ipython:: python
+
+ coords = dict(month=('month', [1, 2, 3]))
+
+ prec = xr.DataArray([1.1, 1.0, 0.9], dims=('month', ), coords=coords)
+ weights = xr.DataArray([31, 28, 31], dims=('month', ), coords=coords)
+
+Create a weighted object:
+
+.. ipython:: python
+
+ weighted_prec = prec.weighted(weights)
+ weighted_prec
+
+Calculate the weighted sum:
+
+.. ipython:: python
+
+ weighted_prec.sum()
+
+Calculate the weighted mean:
+
+.. ipython:: python
+
+ weighted_prec.mean(dim="month")
+
+The weighted sum corresponds to:
+
+.. ipython:: python
+
+ weighted_sum = (prec * weights).sum()
+ weighted_sum
+
+and the weighted mean to:
+
+.. ipython:: python
+
+ weighted_mean = weighted_sum / weights.sum()
+ weighted_mean
+
+However, the functions also take missing values in the data into account:
+
+.. ipython:: python
+
+ data = xr.DataArray([np.NaN, 2, 4])
+ weights = xr.DataArray([8, 1, 1])
+
+ data.weighted(weights).mean()
+
+Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result
+in 0.6.
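+
+Internally, the weighted mean masks the weights where the data is missing,
+and the resulting sum of weights can be inspected with ``sum_of_weights``:
+
+.. ipython:: python
+
+    data.weighted(weights).sum_of_weights()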
+
+
+If the weights add up to 0, ``sum`` returns 0:
+
+.. ipython:: python
+
+ data = xr.DataArray([1.0, 1.0])
+ weights = xr.DataArray([-1.0, 1.0])
+
+ data.weighted(weights).sum()
+
+and ``mean`` returns ``NaN``:
+
+.. ipython:: python
+
+ data.weighted(weights).mean()
+
+
+.. note::
+ ``weights`` must be a :py:class:`DataArray` and cannot contain missing values.
+ Missing values can be replaced manually by ``weights.fillna(0)``.
+
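+For example, missing weights can be set to ``0`` before constructing the
+weighted object:
+
+.. ipython:: python
+
+    weights = xr.DataArray([np.NaN, 1, 1])
+
+    data = xr.DataArray([1.0, 2.0, 3.0])
+    data.weighted(weights.fillna(0)).mean()
+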
.. _comput.coarsen:

Coarsen large arrays
====================

-``DataArray`` and ``Dataset`` objects include a
+:py:class:`DataArray` and :py:class:`Dataset` objects include
:py:meth:`~xarray.DataArray.coarsen` and :py:meth:`~xarray.Dataset.coarsen`
methods. This supports the block aggregation along multiple dimensions,
diff --git a/doc/examples.rst b/doc/examples.rst
index 805395808e0..1d48d29bcc5 100644
--- a/doc/examples.rst
+++ b/doc/examples.rst
@@ -6,6 +6,7 @@ Examples
examples/weather-data
examples/monthly-means
+ examples/area_weighted_temperature
examples/multidimensional-coords
examples/visualization_gallery
examples/ROMS_ocean_model
diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb
new file mode 100644
index 00000000000..72876e3fc29
--- /dev/null
+++ b/doc/examples/area_weighted_temperature.ipynb
@@ -0,0 +1,226 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "toc": true
+ },
+ "source": [
+ "
Table of Contents
\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Compare weighted and unweighted mean temperature\n",
+ "\n",
+ "\n",
+ "Author: [Mathias Hauser](https://github.com/mathause/)\n",
+ "\n",
+ "\n",
+ "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:43:57.222351Z",
+ "start_time": "2020-03-17T14:43:56.147541Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "\n",
+ "import cartopy.crs as ccrs\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "import xarray as xr"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Data\n",
+ "\n",
+ "Load the data, convert to celsius, and resample to daily values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:43:57.831734Z",
+ "start_time": "2020-03-17T14:43:57.651845Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "ds = xr.tutorial.load_dataset(\"air_temperature\")\n",
+ "\n",
+ "# to celsius\n",
+ "air = ds.air - 273.15\n",
+ "\n",
+ "# resample from 6-hourly to daily values\n",
+ "air = air.resample(time=\"D\").mean()\n",
+ "\n",
+ "air"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the first timestep:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:43:59.887120Z",
+ "start_time": "2020-03-17T14:43:59.582894Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "projection = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n",
+ "\n",
+ "f, ax = plt.subplots(subplot_kw=dict(projection=projection))\n",
+ "\n",
+ "air.isel(time=0).plot(transform=ccrs.PlateCarree(), cbar_kwargs=dict(shrink=0.7))\n",
+ "ax.coastlines()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Creating weights\n",
+ "\n",
+ "For a for a rectangular grid the cosine of the latitude is proportional to the grid cell area."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:44:18.777092Z",
+ "start_time": "2020-03-17T14:44:18.736587Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "weights = np.cos(np.deg2rad(air.lat))\n",
+ "weights.name = \"weights\"\n",
+ "weights"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Weighted mean"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:44:52.607120Z",
+ "start_time": "2020-03-17T14:44:52.564674Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "air_weighted = air.weighted(weights)\n",
+ "air_weighted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:44:54.334279Z",
+ "start_time": "2020-03-17T14:44:54.280022Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "weighted_mean = air_weighted.mean((\"lon\", \"lat\"))\n",
+ "weighted_mean"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Plot: comparison with unweighted mean\n",
+ "\n",
+ "Note how the weighted mean temperature is higher than the unweighted."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-03-17T14:45:08.877307Z",
+ "start_time": "2020-03-17T14:45:08.673383Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "weighted_mean.plot(label=\"weighted\")\n",
+ "air.mean((\"lon\", \"lat\")).plot(label=\"unweighted\")\n",
+ "\n",
+ "plt.legend()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.6"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": true,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index aad0e083a8c..5640e872bea 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -25,6 +25,9 @@ Breaking changes
New Features
~~~~~~~~~~~~

+- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted`
+  and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`).
+  By `Mathias Hauser <https://github.com/mathause>`_.
- Added support for :py:class:`pandas.DatetimeIndex`-style rounding of
``cftime.datetime`` objects directly via a :py:class:`CFTimeIndex` or via the
:py:class:`~core.accessor_dt.DatetimeAccessor`.
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 39aa7982091..a003642076f 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -745,6 +745,25 @@ def groupby_bins(
},
)
+ def weighted(self, weights):
+ """
+ Weighted operations.
+
+ Parameters
+ ----------
+ weights : DataArray
+            An array of weights associated with the values in this Dataset
+            or DataArray.
+ Each value in the data contributes to the reduction operation
+ according to its associated weight.
+
+ Notes
+ -----
+ ``weights`` must be a DataArray and cannot contain missing values.
+ Missing values can be replaced by ``weights.fillna(0)``.
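+
+        Examples
+        --------
+        A minimal example; since the weights add up to 1, the weighted mean
+        equals ``(da * weights).sum()``:
+
+        >>> import xarray as xr
+        >>> da = xr.DataArray([1.0, 2.0, 3.0])
+        >>> weights = xr.DataArray([0.5, 0.25, 0.25])
+        >>> da.weighted(weights)
+        DataArrayWeighted with weights along dimensions: dim_0
+        >>> float(da.weighted(weights).mean())
+        1.75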
+ """
+
+ return self._weighted_cls(self, weights)
+
def rolling(
self,
dim: Mapping[Hashable, int] = None,
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index b335eeb293b..4b3ecb2744c 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -33,6 +33,7 @@
resample,
rolling,
utils,
+ weighted,
)
from .accessor_dt import CombinedDatetimelikeAccessor
from .accessor_str import StringAccessor
@@ -258,6 +259,7 @@ class DataArray(AbstractArray, DataWithCoords):
_rolling_cls = rolling.DataArrayRolling
_coarsen_cls = rolling.DataArrayCoarsen
_resample_cls = resample.DataArrayResample
+ _weighted_cls = weighted.DataArrayWeighted
dt = property(CombinedDatetimelikeAccessor)
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index d5ad1123a54..c10447f6d11 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -46,6 +46,7 @@
resample,
rolling,
utils,
+ weighted,
)
from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
from .common import (
@@ -457,6 +458,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
_rolling_cls = rolling.DatasetRolling
_coarsen_cls = rolling.DatasetCoarsen
_resample_cls = resample.DatasetResample
+ _weighted_cls = weighted.DatasetWeighted
def __init__(
self,
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
new file mode 100644
index 00000000000..996d2e4c43e
--- /dev/null
+++ b/xarray/core/weighted.py
@@ -0,0 +1,255 @@
+from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
+
+from .computation import dot
+from .options import _get_keep_attrs
+
+if TYPE_CHECKING:
+    from .dataarray import DataArray
+    from .dataset import Dataset
+
+_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
+ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
+
+ Parameters
+ ----------
+ dim : str or sequence of str, optional
+ Dimension(s) over which to apply the weighted ``{fcn}``.
+ skipna : bool, optional
+ If True, skip missing values (as marked by NaN). By default, only
+ skips missing values for float dtypes; other dtypes either do not
+ have a sentinel missing value (int) or skipna=True has not been
+ implemented (object, datetime64 or timedelta64).
+ keep_attrs : bool, optional
+ If True, the attributes (``attrs``) will be copied from the original
+ object to the new one. If False (default), the new object will be
+ returned without attributes.
+
+ Returns
+ -------
+ reduced : {cls}
+ New {cls} object with weighted ``{fcn}`` applied to its data and
+ the indicated dimension(s) removed.
+
+ Notes
+ -----
+ Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced
+ dimension(s).
+ """
+
+_SUM_OF_WEIGHTS_DOCSTRING = """
+    Calculate the sum of weights, accounting for missing values in the data.
+
+ Parameters
+ ----------
+ dim : str or sequence of str, optional
+ Dimension(s) over which to sum the weights.
+ keep_attrs : bool, optional
+ If True, the attributes (``attrs``) will be copied from the original
+ object to the new one. If False (default), the new object will be
+ returned without attributes.
+
+ Returns
+ -------
+ reduced : {cls}
+ New {cls} object with the sum of the weights over the given dimension.
+ """
+
+
+class Weighted:
+ """An object that implements weighted operations.
+
+ You should create a Weighted object by using the ``DataArray.weighted`` or
+ ``Dataset.weighted`` methods.
+
+ See Also
+ --------
+ Dataset.weighted
+ DataArray.weighted
+ """
+
+ __slots__ = ("obj", "weights")
+
+ @overload
+ def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
+ ...
+
+ @overload # noqa: F811
+ def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811
+ ...
+
+ def __init__(self, obj, weights): # noqa: F811
+ """
+ Create a Weighted object
+
+ Parameters
+ ----------
+ obj : DataArray or Dataset
+ Object over which the weighted reduction operation is applied.
+ weights : DataArray
+ An array of weights associated with the values in the obj.
+ Each value in the obj contributes to the reduction operation
+ according to its associated weight.
+
+ Notes
+ -----
+ ``weights`` must be a ``DataArray`` and cannot contain missing values.
+ Missing values can be replaced by ``weights.fillna(0)``.
+ """
+
+ from .dataarray import DataArray
+
+ if not isinstance(weights, DataArray):
+ raise ValueError("`weights` must be a DataArray")
+
+ if weights.isnull().any():
+ raise ValueError(
+ "`weights` cannot contain missing values. "
+ "Missing values can be replaced by `weights.fillna(0)`."
+ )
+
+ self.obj = obj
+ self.weights = weights
+
+ @staticmethod
+ def _reduce(
+ da: "DataArray",
+ weights: "DataArray",
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ skipna: Optional[bool] = None,
+ ) -> "DataArray":
+ """reduce using dot; equivalent to (da * weights).sum(dim, skipna)
+
+ for internal use only
+ """
+
+ # need to infer dims as we use `dot`
+ if dim is None:
+ dim = ...
+
+ # need to mask invalid values in da, as `dot` does not implement skipna
+ if skipna or (skipna is None and da.dtype.kind in "cfO"):
+ da = da.fillna(0.0)
+
+ # `dot` does not broadcast arrays, so this avoids creating a large
+ # DataArray (if `weights` has additional dimensions)
+ # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
+ return dot(da, weights, dims=dim)
+
+ def _sum_of_weights(
+ self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None
+ ) -> "DataArray":
+ """ Calculate the sum of weights, accounting for missing values """
+
+ # we need to mask data values that are nan; else the weights are wrong
+ mask = da.notnull()
+
+ sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)
+
+ # 0-weights are not valid
+ valid_weights = sum_of_weights != 0.0
+
+ return sum_of_weights.where(valid_weights)
+
+ def _weighted_sum(
+ self,
+ da: "DataArray",
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ skipna: Optional[bool] = None,
+ ) -> "DataArray":
+ """Reduce a DataArray by a by a weighted ``sum`` along some dimension(s)."""
+
+ return self._reduce(da, self.weights, dim=dim, skipna=skipna)
+
+ def _weighted_mean(
+ self,
+ da: "DataArray",
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ skipna: Optional[bool] = None,
+ ) -> "DataArray":
+ """Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
+
+ weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
+
+ sum_of_weights = self._sum_of_weights(da, dim=dim)
+
+ return weighted_sum / sum_of_weights
+
+ def _implementation(self, func, dim, **kwargs):
+
+ raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
+
+ def sum_of_weights(
+ self,
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ keep_attrs: Optional[bool] = None,
+ ) -> Union["DataArray", "Dataset"]:
+
+ return self._implementation(
+ self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
+ )
+
+ def sum(
+ self,
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ skipna: Optional[bool] = None,
+ keep_attrs: Optional[bool] = None,
+ ) -> Union["DataArray", "Dataset"]:
+
+ return self._implementation(
+ self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ def mean(
+ self,
+ dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+ skipna: Optional[bool] = None,
+ keep_attrs: Optional[bool] = None,
+ ) -> Union["DataArray", "Dataset"]:
+
+ return self._implementation(
+ self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ def __repr__(self):
+ """provide a nice str repr of our Weighted object"""
+
+ klass = self.__class__.__name__
+ weight_dims = ", ".join(self.weights.dims)
+ return f"{klass} with weights along dimensions: {weight_dims}"
+
+
+class DataArrayWeighted(Weighted):
+ def _implementation(self, func, dim, **kwargs):
+
+ keep_attrs = kwargs.pop("keep_attrs")
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ weighted = func(self.obj, dim=dim, **kwargs)
+
+ if keep_attrs:
+ weighted.attrs = self.obj.attrs
+
+ return weighted
+
+
+class DatasetWeighted(Weighted):
+ def _implementation(self, func, dim, **kwargs) -> "Dataset":
+
+ return self.obj.map(func, dim=dim, **kwargs)
+
+
+def _inject_docstring(cls, cls_name):
+
+ cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name)
+
+ cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="sum", on_zero="0"
+ )
+
+ cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="mean", on_zero="NaN"
+ )
+
+
+_inject_docstring(DataArrayWeighted, "DataArray")
+_inject_docstring(DatasetWeighted, "Dataset")
diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py
new file mode 100644
index 00000000000..24531215dfb
--- /dev/null
+++ b/xarray/tests/test_weighted.py
@@ -0,0 +1,311 @@
+import numpy as np
+import pytest
+
+import xarray as xr
+from xarray import DataArray
+from xarray.tests import assert_allclose, assert_equal, raises_regex
+
+
+@pytest.mark.parametrize("as_dataset", (True, False))
+def test_weighted_non_DataArray_weights(as_dataset):
+
+ data = DataArray([1, 2])
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ with raises_regex(ValueError, "`weights` must be a DataArray"):
+ data.weighted([1, 2])
+
+
+@pytest.mark.parametrize("as_dataset", (True, False))
+@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
+def test_weighted_weights_nan_raises(as_dataset, weights):
+
+ data = DataArray([1, 2])
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
+ data.weighted(DataArray(weights))
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"),
+ (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),
+)
+def test_weighted_sum_of_weights_no_nan(weights, expected):
+
+ da = DataArray([1, 2])
+ weights = DataArray(weights)
+ result = da.weighted(weights).sum_of_weights()
+
+ expected = DataArray(expected)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"),
+ (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)),
+)
+def test_weighted_sum_of_weights_nan(weights, expected):
+
+ da = DataArray([np.nan, 2])
+ weights = DataArray(weights)
+ result = da.weighted(weights).sum_of_weights()
+
+ expected = DataArray(expected)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
+@pytest.mark.parametrize("factor", [0, 1, 3.14])
+@pytest.mark.parametrize("skipna", (True, False))
+def test_weighted_sum_equal_weights(da, factor, skipna):
+    # if all weights equal `factor`, the weighted sum is `factor` times the ordinary sum
+
+ da = DataArray(da)
+ weights = xr.full_like(da, factor)
+
+ expected = da.sum(skipna=skipna) * factor
+ result = da.weighted(weights).sum(skipna=skipna)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0))
+)
+def test_weighted_sum_no_nan(weights, expected):
+
+ da = DataArray([1, 2])
+
+ weights = DataArray(weights)
+ result = da.weighted(weights).sum()
+ expected = DataArray(expected)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0))
+)
+@pytest.mark.parametrize("skipna", (True, False))
+def test_weighted_sum_nan(weights, expected, skipna):
+
+ da = DataArray([np.nan, 2])
+
+ weights = DataArray(weights)
+ result = da.weighted(weights).sum(skipna=skipna)
+
+ if skipna:
+ expected = DataArray(expected)
+ else:
+ expected = DataArray(np.nan)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.filterwarnings("ignore:Mean of empty slice")
+@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
+@pytest.mark.parametrize("skipna", (True, False))
+@pytest.mark.parametrize("factor", [1, 2, 3.14])
+def test_weighted_mean_equal_weights(da, skipna, factor):
+ # if all weights are equal (!= 0), should yield the same result as mean
+
+ da = DataArray(da)
+
+    # all weights equal; the exact value is irrelevant as long as it is non-zero
+ weights = xr.full_like(da, factor)
+
+ expected = da.mean(skipna=skipna)
+ result = da.weighted(weights).mean(skipna=skipna)
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan))
+)
+def test_weighted_mean_no_nan(weights, expected):
+
+ da = DataArray([1, 2])
+ weights = DataArray(weights)
+ expected = DataArray(expected)
+
+ result = da.weighted(weights).mean()
+
+ assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+ ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan))
+)
+@pytest.mark.parametrize("skipna", (True, False))
+def test_weighted_mean_nan(weights, expected, skipna):
+
+ da = DataArray([np.nan, 2])
+ weights = DataArray(weights)
+
+ if skipna:
+ expected = DataArray(expected)
+ else:
+ expected = DataArray(np.nan)
+
+ result = da.weighted(weights).mean(skipna=skipna)
+
+ assert_equal(expected, result)
+
+
+def expected_weighted(da, weights, dim, skipna, operation):
+ """
+ Generate expected result using ``*`` and ``sum``. This is checked against
+ the result of da.weighted which uses ``dot``
+ """
+
+ weighted_sum = (da * weights).sum(dim=dim, skipna=skipna)
+
+ if operation == "sum":
+ return weighted_sum
+
+ masked_weights = weights.where(da.notnull())
+ sum_of_weights = masked_weights.sum(dim=dim, skipna=True)
+ valid_weights = sum_of_weights != 0
+ sum_of_weights = sum_of_weights.where(valid_weights)
+
+ if operation == "sum_of_weights":
+ return sum_of_weights
+
+ weighted_mean = weighted_sum / sum_of_weights
+
+ if operation == "mean":
+ return weighted_mean
+
+
+@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
+@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
+@pytest.mark.parametrize("add_nans", (True, False))
+@pytest.mark.parametrize("skipna", (None, True, False))
+@pytest.mark.parametrize("as_dataset", (True, False))
+def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset):
+
+ dims = ("a", "b", "c")
+ coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3])
+
+ weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords)
+
+ data = np.random.randn(4, 4, 4)
+
+ # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700)
+ if add_nans:
+ c = int(data.size * 0.25)
+ data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN
+
+ data = DataArray(data, dims=dims, coords=coords)
+
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ if operation == "sum_of_weights":
+ result = data.weighted(weights).sum_of_weights(dim)
+ else:
+ result = getattr(data.weighted(weights), operation)(dim, skipna=skipna)
+
+ expected = expected_weighted(data, weights, dim, skipna, operation)
+
+ assert_allclose(expected, result)
+
+
+@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
+@pytest.mark.parametrize("as_dataset", (True, False))
+def test_weighted_operations_nonequal_coords(operation, as_dataset):
+
+ weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
+ data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))
+
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ expected = expected_weighted(
+ data, weights, dim="a", skipna=None, operation=operation
+ )
+ result = getattr(data.weighted(weights), operation)(dim="a")
+
+ assert_allclose(expected, result)
+
+
+@pytest.mark.parametrize("dim", ("dim_0", None))
+@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
+@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
+@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
+@pytest.mark.parametrize("add_nans", (True, False))
+@pytest.mark.parametrize("skipna", (None, True, False))
+@pytest.mark.parametrize("as_dataset", (True, False))
+def test_weighted_operations_different_shapes(
+ dim, shape_data, shape_weights, operation, add_nans, skipna, as_dataset
+):
+
+ weights = DataArray(np.random.randn(*shape_weights))
+
+ data = np.random.randn(*shape_data)
+
+ # add approximately 25 % NaNs
+ if add_nans:
+ c = int(data.size * 0.25)
+ data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN
+
+ data = DataArray(data)
+
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ if operation == "sum_of_weights":
+ result = getattr(data.weighted(weights), operation)(dim)
+ else:
+ result = getattr(data.weighted(weights), operation)(dim, skipna=skipna)
+
+ expected = expected_weighted(data, weights, dim, skipna, operation)
+
+ assert_allclose(expected, result)
+
+
+@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
+@pytest.mark.parametrize("as_dataset", (True, False))
+@pytest.mark.parametrize("keep_attrs", (True, False, None))
+def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
+
+ weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights"))
+ data = DataArray(np.random.randn(2, 2))
+
+ if as_dataset:
+ data = data.to_dataset(name="data")
+
+ data.attrs = dict(attr="weights")
+
+ result = getattr(data.weighted(weights), operation)(keep_attrs=True)
+
+ if operation == "sum_of_weights":
+ assert weights.attrs == result.attrs
+ else:
+ assert data.attrs == result.attrs
+
+ result = getattr(data.weighted(weights), operation)(keep_attrs=None)
+ assert not result.attrs
+
+ result = getattr(data.weighted(weights), operation)(keep_attrs=False)
+ assert not result.attrs
+
+
+@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595")
+@pytest.mark.parametrize("operation", ("sum", "mean"))
+def test_weighted_operations_keep_attr_da_in_ds(operation):
+ # GH #3595
+
+ weights = DataArray(np.random.randn(2, 2))
+ data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data"))
+ data = data.to_dataset(name="a")
+
+ result = getattr(data.weighted(weights), operation)(keep_attrs=True)
+
+ assert data.a.attrs == result.a.attrs