Skip to content

Commit 5a2eadc

Browse files
committed
more test cleanup
1 parent b509ebf commit 5a2eadc

File tree

1 file changed

+33
-39
lines changed

1 file changed

+33
-39
lines changed

xarray/tests/test_dataarray.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6790,47 +6790,41 @@ def test_rolling_reduce_nonnumeric(center, pad, min_periods, window, name):
67906790
assert actual.dims == expected.dims
67916791

67926792

6793-
def test_rolling_count_correct():
6793+
@pytest.mark.parametrize(
6794+
"time, min_periods, pad, expected",
6795+
(
6796+
[11, 1, True, DataArray([1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims="time")],
6797+
[11, 1, False, DataArray([8], dims="time")],
6798+
[
6799+
11,
6800+
None,
6801+
True,
6802+
DataArray(
6803+
[np.nan] * 11,
6804+
dims="time",
6805+
),
6806+
],
6807+
[11, None, False, DataArray([np.nan], dims="time")],
6808+
[
6809+
7,
6810+
2,
6811+
True,
6812+
DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"),
6813+
],
6814+
[7, 2, False, DataArray([5, 5, 5, 5, 5], dims="time")],
6815+
),
6816+
)
6817+
def test_rolling_count_correct(time, min_periods, pad, expected):
67946818
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
6819+
result = da.rolling(time=time, min_periods=min_periods, pad=pad).count()
6820+
assert_equal(result, expected)
67956821

6796-
kwargs = [
6797-
{"time": 11, "min_periods": 1},
6798-
{"time": 11, "min_periods": 1, "pad": False},
6799-
{"time": 11, "min_periods": None},
6800-
{"time": 11, "min_periods": None, "pad": False},
6801-
{"time": 7, "min_periods": 2},
6802-
{"time": 7, "min_periods": 2, "pad": False},
6803-
]
6804-
expecteds = [
6805-
DataArray([1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims="time"),
6806-
DataArray([8], dims="time"),
6807-
DataArray(
6808-
[
6809-
np.nan,
6810-
np.nan,
6811-
np.nan,
6812-
np.nan,
6813-
np.nan,
6814-
np.nan,
6815-
np.nan,
6816-
np.nan,
6817-
np.nan,
6818-
np.nan,
6819-
np.nan,
6820-
],
6821-
dims="time",
6822-
),
6823-
DataArray([np.nan], dims="time"),
6824-
DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"),
6825-
DataArray([5, 5, 5, 5, 5], dims="time"),
6826-
]
6827-
6828-
for kwarg, expected in zip(kwargs, expecteds):
6829-
result = da.rolling(**kwarg).count()
6830-
assert_equal(result, expected)
6831-
6832-
result = da.to_dataset(name="var1").rolling(**kwarg).count()["var1"]
6833-
assert_equal(result, expected)
6822+
result = (
6823+
da.to_dataset(name="var1")
6824+
.rolling(time=time, min_periods=min_periods, pad=pad)
6825+
.count()["var1"]
6826+
)
6827+
assert_equal(result, expected)
68346828

68356829

68366830
@pytest.mark.parametrize("da", (1,), indirect=True)

0 commit comments

Comments
 (0)