From e7e14801087fa34705667cf83b419f953182a688 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 8 Dec 2023 14:07:55 -0800 Subject: [PATCH 1/3] Add Cumulative aggregation Offer a fixture for unifying `DataArray` & `Dataset` tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (stacked on #8512, worth reviewing after that's merged) Some tests are literally copied & pasted between the DataArray & Dataset tests. This change allows them to share a single test. Not every test can be unified — sometimes we want to check type-specific details — but many can. --- xarray/tests/conftest.py | 43 +++++++++++++++++++++++ xarray/tests/test_rolling.py | 67 ++++++++++++++---------------------- 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 6a8cf008f9f..ca46b369726 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pandas as pd import pytest @@ -77,3 +79,44 @@ def da(request, backend): return da else: raise ValueError + + +@pytest.fixture(params=[Dataset, DataArray]) +def type(request): + return request.param + + +@pytest.fixture(params=[1]) +def d(request, backend, type) -> DataArray | Dataset: + """ + For tests which can test either a DataArray or a Dataset. 
+ """ + result: DataArray | Dataset + if request.param == 1: + ds = Dataset( + dict( + a=(["x", "y"], np.arange(16).reshape(8, 2)), + b=(["y", "z"], np.arange(12, 32).reshape(2, 10).astype(np.float64)), + ), + dict( + x=("x", np.linspace(0, 1.0, 8)), + y=range(2), + z=("z", np.linspace(0, 1.0, 10)), + w=("y", ["a", "b"]), + ), + ) + if type == DataArray: + result = ds["a"].assign_coords(w=ds.coords["w"]) + elif type == Dataset: + result = ds + else: + raise ValueError + else: + raise ValueError + + if backend == "dask": + return result.chunk() + elif backend == "numpy": + return result + else: + raise ValueError diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 645ec1f85e6..647288e3ade 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -36,6 +36,31 @@ def compute_backend(request): yield request.param +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("min_periods", [1, 20]) +def test_cumulative(d, func, min_periods) -> None: + # One dim + result = getattr(d.cumulative("x", min_periods=min_periods), func)() + expected = getattr(d.rolling(x=d["x"].size, min_periods=min_periods), func)() + assert_identical(result, expected) + + # Multiple dim + result = getattr(d.cumulative(["x", "y"], min_periods=min_periods), func)() + expected = getattr( + d.rolling(x=d["x"].size, y=d["y"].size, min_periods=min_periods), + func, + )() + assert_identical(result, expected) + + +def test_cumulative_vs_cum(d) -> None: + result = d.cumulative("x").sum() + expected = d.cumsum("x") + # cumsum drops the coord of the dimension; cumulative doesn't + expected = expected.assign_coords(x=result["x"]) + assert_identical(result, expected) + + class TestDataArrayRolling: @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", [True, False]) @@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None: ): da.rolling_exp(time=10, keep_attrs=True) - 
@pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("min_periods", [1, 20]) - def test_cumulative(self, da, func, min_periods) -> None: - # One dim - result = getattr(da.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - - def test_cumulative_vs_cum(self, da) -> None: - result = da.cumulative("time").sum() - expected = da.cumsum("time") - assert_identical(result, expected) - class TestDatasetRolling: @pytest.mark.parametrize( @@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)() assert_allclose(actual, expected) - @pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("ds", (2,), indirect=True) - @pytest.mark.parametrize("min_periods", [1, 10]) - def test_cumulative(self, ds, func, min_periods) -> None: - # One dim - result = getattr(ds.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - @requires_numbagg class TestDatasetRollingExp: From db2a141bb841b8331b1ed858db4f873ee00b6cb6 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 8 Dec 2023 14:23:07 -0800 Subject: [PATCH 2/3] --- xarray/tests/conftest.py | 12 ++++++------ xarray/tests/test_rolling.py | 10 
+++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index ca46b369726..f153c2f4dc0 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -95,14 +95,14 @@ def d(request, backend, type) -> DataArray | Dataset: if request.param == 1: ds = Dataset( dict( - a=(["x", "y"], np.arange(16).reshape(8, 2)), - b=(["y", "z"], np.arange(12, 32).reshape(2, 10).astype(np.float64)), + a=(["x", "z"], np.arange(24).reshape(2, 12)), + b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)), ), dict( - x=("x", np.linspace(0, 1.0, 8)), - y=range(2), - z=("z", np.linspace(0, 1.0, 10)), - w=("y", ["a", "b"]), + x=("x", np.linspace(0, 1.0, 2)), + y=range(3), + z=("z", pd.date_range("2000-01-01", periods=12)), + w=("x", ["a", "b"]), ), ) if type == DataArray: diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 647288e3ade..5240eabca0b 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -37,17 +37,17 @@ def compute_backend(request): @pytest.mark.parametrize("func", ["mean", "sum"]) -@pytest.mark.parametrize("min_periods", [1, 20]) +@pytest.mark.parametrize("min_periods", [1, 10]) def test_cumulative(d, func, min_periods) -> None: # One dim - result = getattr(d.cumulative("x", min_periods=min_periods), func)() - expected = getattr(d.rolling(x=d["x"].size, min_periods=min_periods), func)() + result = getattr(d.cumulative("z", min_periods=min_periods), func)() + expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)() assert_identical(result, expected) # Multiple dim - result = getattr(d.cumulative(["x", "y"], min_periods=min_periods), func)() + result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)() expected = getattr( - d.rolling(x=d["x"].size, y=d["y"].size, min_periods=min_periods), + d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods), func, )() assert_identical(result, 
expected) From 34728164be745b1ad217c4791ca00498b89dd805 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 8 Dec 2023 14:23:59 -0800 Subject: [PATCH 3/3] --- xarray/tests/test_rolling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 5240eabca0b..7cb2cd70d29 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -54,10 +54,10 @@ def test_cumulative(d, func, min_periods) -> None: def test_cumulative_vs_cum(d) -> None: - result = d.cumulative("x").sum() - expected = d.cumsum("x") + result = d.cumulative("z").sum() + expected = d.cumsum("z") # cumsum drops the coord of the dimension; cumulative doesn't - expected = expected.assign_coords(x=result["x"]) + expected = expected.assign_coords(z=result["z"]) assert_identical(result, expected)