Skip to content

Commit 1f94829

Browse files
Use numbagg for rolling methods (#8493)
* Use numbagg for `rolling` methods A couple of tests are failing for the multi-dimensional case, which I'll fix before merge. * wip * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * whatsnew --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 704de55 commit 1f94829

File tree

3 files changed

+144
-27
lines changed

3 files changed

+144
-27
lines changed

doc/whats-new.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ v2023.11.1 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- :py:meth:`rolling` uses numbagg <https://github.com/numbagg/numbagg>`_ for
27+
most of its computations by default. Numbagg is up to 5x faster than bottleneck
28+
where parallelization is possible. Where parallelization isn't possible — for
29+
example a 1D array — it's about the same speed as bottleneck, and 2-5x faster
30+
than pandas' default functions. (:pull:`8493`). numbagg is an optional
31+
dependency, so requires installing separately.
32+
By `Maximilian Roos <https://github.com/max-sixty>`_.
2633
- Use a concise format when plotting datetime arrays. (:pull:`8449`).
2734
By `Jimmy Westling <https://github.com/illviljan>`_.
2835
- Avoid overwriting unchanged existing coordinate variables when appending by setting ``mode='a-'``.
@@ -90,7 +97,7 @@ Documentation
9097
Internal Changes
9198
~~~~~~~~~~~~~~~~
9299

93-
- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg by
100+
- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg <https://github.com/numbagg/numbagg>`_ by
94101
default, which is up to 5x faster where parallelization is possible. (:pull:`8339`)
95102
By `Maximilian Roos <https://github.com/max-sixty>`_.
96103
- Update mypy version to 1.7 (:issue:`8448`, :pull:`8501`).

xarray/core/rolling.py

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
99

1010
import numpy as np
11+
from packaging.version import Version
1112

12-
from xarray.core import dtypes, duck_array_ops, utils
13+
from xarray.core import dtypes, duck_array_ops, pycompat, utils
1314
from xarray.core.arithmetic import CoarsenArithmetic
1415
from xarray.core.options import OPTIONS, _get_keep_attrs
1516
from xarray.core.pycompat import is_duck_dask_array
1617
from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray
17-
from xarray.core.utils import either_dict_or_kwargs
18+
from xarray.core.utils import either_dict_or_kwargs, module_available
1819

1920
try:
2021
import bottleneck
@@ -145,22 +146,35 @@ def _reduce_method( # type: ignore[misc]
145146
name: str, fillna: Any, rolling_agg_func: Callable | None = None
146147
) -> Callable[..., T_Xarray]:
147148
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
148-
a bottleneck reduction function (e.g. move_sum), or a Rolling reduction (_mean).
149+
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
150+
(e.g. move_sum), or a Rolling reduction (_mean).
151+
152+
The logic here for which function to run is quite diffuse, across this method &
153+
_array_reduce. Arguably we could refactor this. But one constraint is that we
154+
need context of xarray options, of the functions each library offers, of
155+
the array (e.g. dtype).
149156
"""
150157
if rolling_agg_func:
151158
array_agg_func = None
152159
else:
153160
array_agg_func = getattr(duck_array_ops, name)
154161

155162
bottleneck_move_func = getattr(bottleneck, "move_" + name, None)
163+
if module_available("numbagg"):
164+
import numbagg
165+
166+
numbagg_move_func = getattr(numbagg, "move_" + name, None)
167+
else:
168+
numbagg_move_func = None
156169

157170
def method(self, keep_attrs=None, **kwargs):
158171
keep_attrs = self._get_keep_attrs(keep_attrs)
159172

160-
return self._numpy_or_bottleneck_reduce(
161-
array_agg_func,
162-
bottleneck_move_func,
163-
rolling_agg_func,
173+
return self._array_reduce(
174+
array_agg_func=array_agg_func,
175+
bottleneck_move_func=bottleneck_move_func,
176+
numbagg_move_func=numbagg_move_func,
177+
rolling_agg_func=rolling_agg_func,
164178
keep_attrs=keep_attrs,
165179
fillna=fillna,
166180
**kwargs,
@@ -510,9 +524,47 @@ def _counts(self, keep_attrs: bool | None) -> DataArray:
510524
)
511525
return counts
512526

513-
def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
514-
from xarray.core.dataarray import DataArray
527+
def _numbagg_reduce(self, func, keep_attrs, **kwargs):
528+
# Some of this is copied from `_bottleneck_reduce`, we could reduce this as part
529+
# of a wider refactor.
530+
531+
axis = self.obj.get_axis_num(self.dim[0])
515532

533+
padded = self.obj.variable
534+
if self.center[0]:
535+
if is_duck_dask_array(padded.data):
536+
# workaround to make the padded chunk size larger than
537+
# self.window - 1
538+
shift = -(self.window[0] + 1) // 2
539+
offset = (self.window[0] - 1) // 2
540+
valid = (slice(None),) * axis + (
541+
slice(offset, offset + self.obj.shape[axis]),
542+
)
543+
else:
544+
shift = (-self.window[0] // 2) + 1
545+
valid = (slice(None),) * axis + (slice(-shift, None),)
546+
padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
547+
548+
if is_duck_dask_array(padded.data) and False:
549+
raise AssertionError("should not be reachable")
550+
else:
551+
values = func(
552+
padded.data,
553+
window=self.window[0],
554+
min_count=self.min_periods,
555+
axis=axis,
556+
)
557+
558+
if self.center[0]:
559+
values = values[valid]
560+
561+
attrs = self.obj.attrs if keep_attrs else {}
562+
563+
return self.obj.__class__(
564+
values, self.obj.coords, attrs=attrs, name=self.obj.name
565+
)
566+
567+
def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
516568
# bottleneck doesn't allow min_count to be 0, although it should
517569
# work the same as if min_count = 1
518570
# Note bottleneck only works with 1d-rolling.
@@ -550,12 +602,15 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
550602

551603
attrs = self.obj.attrs if keep_attrs else {}
552604

553-
return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name)
605+
return self.obj.__class__(
606+
values, self.obj.coords, attrs=attrs, name=self.obj.name
607+
)
554608

555-
def _numpy_or_bottleneck_reduce(
609+
def _array_reduce(
556610
self,
557611
array_agg_func,
558612
bottleneck_move_func,
613+
numbagg_move_func,
559614
rolling_agg_func,
560615
keep_attrs,
561616
fillna,
@@ -571,6 +626,35 @@ def _numpy_or_bottleneck_reduce(
571626
)
572627
del kwargs["dim"]
573628

629+
if (
630+
OPTIONS["use_numbagg"]
631+
and module_available("numbagg")
632+
and pycompat.mod_version("numbagg") >= Version("0.6.3")
633+
and numbagg_move_func is not None
634+
# TODO: we could at least allow this for the equivalent of `apply_ufunc`'s
635+
# "parallelized". `rolling_exp` does this, as an example (but rolling_exp is
636+
# much simpler)
637+
and not is_duck_dask_array(self.obj.data)
638+
# Numbagg doesn't handle object arrays and generally has dtype consistency,
639+
# so doesn't deal well with bool arrays which are expected to change type.
640+
and self.obj.data.dtype.kind not in "ObMm"
641+
# TODO: we could also allow this, probably as part of a refactoring of this
642+
# module, so we can use the machinery in `self.reduce`.
643+
and self.ndim == 1
644+
):
645+
import numbagg
646+
647+
# Numbagg has a default ddof of 1. I (@max-sixty) think we should make
648+
# this the default in xarray too, but until we do, don't use numbagg for
649+
# std and var unless ddof is set to 1.
650+
if (
651+
numbagg_move_func not in [numbagg.move_std, numbagg.move_var]
652+
or kwargs.get("ddof") == 1
653+
):
654+
return self._numbagg_reduce(
655+
numbagg_move_func, keep_attrs=keep_attrs, **kwargs
656+
)
657+
574658
if (
575659
OPTIONS["use_bottleneck"]
576660
and bottleneck_move_func is not None
@@ -583,8 +667,10 @@ def _numpy_or_bottleneck_reduce(
583667
return self._bottleneck_reduce(
584668
bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
585669
)
670+
586671
if rolling_agg_func:
587672
return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs))
673+
588674
if fillna is not None:
589675
if fillna is dtypes.INF:
590676
fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True)
@@ -705,7 +791,7 @@ def _counts(self, keep_attrs: bool | None) -> Dataset:
705791
DataArrayRolling._counts, keep_attrs=keep_attrs
706792
)
707793

708-
def _numpy_or_bottleneck_reduce(
794+
def _array_reduce(
709795
self,
710796
array_agg_func,
711797
bottleneck_move_func,
@@ -715,7 +801,7 @@ def _numpy_or_bottleneck_reduce(
715801
):
716802
return self._dataset_implementation(
717803
functools.partial(
718-
DataArrayRolling._numpy_or_bottleneck_reduce,
804+
DataArrayRolling._array_reduce,
719805
array_agg_func=array_agg_func,
720806
bottleneck_move_func=bottleneck_move_func,
721807
rolling_agg_func=rolling_agg_func,

xarray/tests/test_rolling.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from xarray import DataArray, Dataset, set_options
1111
from xarray.tests import (
1212
assert_allclose,
13-
assert_array_equal,
1413
assert_equal,
1514
assert_identical,
1615
has_dask,
@@ -24,6 +23,19 @@
2423
]
2524

2625

26+
@pytest.fixture(params=["numbagg", "bottleneck"])
27+
def compute_backend(request):
28+
if request.param == "bottleneck":
29+
options = dict(use_bottleneck=True, use_numbagg=False)
30+
elif request.param == "numbagg":
31+
options = dict(use_bottleneck=False, use_numbagg=True)
32+
else:
33+
raise ValueError
34+
35+
with xr.set_options(**options):
36+
yield request.param
37+
38+
2739
class TestDataArrayRolling:
2840
@pytest.mark.parametrize("da", (1, 2), indirect=True)
2941
@pytest.mark.parametrize("center", [True, False])
@@ -87,9 +99,10 @@ def test_rolling_properties(self, da) -> None:
8799
@pytest.mark.parametrize("center", (True, False, None))
88100
@pytest.mark.parametrize("min_periods", (1, None))
89101
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
90-
def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None:
102+
def test_rolling_wrapped_bottleneck(
103+
self, da, name, center, min_periods, compute_backend
104+
) -> None:
91105
bn = pytest.importorskip("bottleneck", minversion="1.1")
92-
93106
# Test all bottleneck functions
94107
rolling_obj = da.rolling(time=7, min_periods=min_periods)
95108

@@ -98,15 +111,18 @@ def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None
98111
expected = getattr(bn, func_name)(
99112
da.values, window=7, axis=1, min_count=min_periods
100113
)
101-
assert_array_equal(actual.values, expected)
114+
115+
# Using assert_allclose because we get tiny (1e-17) differences in numbagg.
116+
np.testing.assert_allclose(actual.values, expected)
102117

103118
with pytest.warns(DeprecationWarning, match="Reductions are applied"):
104119
getattr(rolling_obj, name)(dim="time")
105120

106121
# Test center
107122
rolling_obj = da.rolling(time=7, center=center)
108123
actual = getattr(rolling_obj, name)()["time"]
109-
assert_equal(actual, da["time"])
124+
# Using assert_allclose because we get tiny (1e-17) differences in numbagg.
125+
assert_allclose(actual, da["time"])
110126

111127
@requires_dask
112128
@pytest.mark.parametrize("name", ("mean", "count"))
@@ -153,7 +169,9 @@ def test_rolling_wrapped_dask_nochunk(self, center) -> None:
153169
@pytest.mark.parametrize("center", (True, False))
154170
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
155171
@pytest.mark.parametrize("window", (1, 2, 3, 4))
156-
def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
172+
def test_rolling_pandas_compat(
173+
self, center, window, min_periods, compute_backend
174+
) -> None:
157175
s = pd.Series(np.arange(10))
158176
da = DataArray.from_series(s)
159177

@@ -203,7 +221,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
203221
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
204222
@pytest.mark.parametrize("window", (1, 2, 3, 4))
205223
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
206-
def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
224+
def test_rolling_reduce(
225+
self, da, center, min_periods, window, name, compute_backend
226+
) -> None:
207227
if min_periods is not None and window < min_periods:
208228
min_periods = window
209229

@@ -223,7 +243,9 @@ def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
223243
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
224244
@pytest.mark.parametrize("window", (1, 2, 3, 4))
225245
@pytest.mark.parametrize("name", ("sum", "max"))
226-
def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> None:
246+
def test_rolling_reduce_nonnumeric(
247+
self, center, min_periods, window, name, compute_backend
248+
) -> None:
227249
da = DataArray(
228250
[0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"
229251
).isnull()
@@ -239,7 +261,7 @@ def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> N
239261
assert_allclose(actual, expected)
240262
assert actual.dims == expected.dims
241263

242-
def test_rolling_count_correct(self) -> None:
264+
def test_rolling_count_correct(self, compute_backend) -> None:
243265
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
244266

245267
kwargs: list[dict[str, Any]] = [
@@ -279,7 +301,9 @@ def test_rolling_count_correct(self) -> None:
279301
@pytest.mark.parametrize("center", (True, False))
280302
@pytest.mark.parametrize("min_periods", (None, 1))
281303
@pytest.mark.parametrize("name", ("sum", "mean", "max"))
282-
def test_ndrolling_reduce(self, da, center, min_periods, name) -> None:
304+
def test_ndrolling_reduce(
305+
self, da, center, min_periods, name, compute_backend
306+
) -> None:
283307
rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods)
284308

285309
actual = getattr(rolling_obj, name)()
@@ -560,7 +584,7 @@ def test_rolling_properties(self, ds) -> None:
560584
@pytest.mark.parametrize("key", ("z1", "z2"))
561585
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
562586
def test_rolling_wrapped_bottleneck(
563-
self, ds, name, center, min_periods, key
587+
self, ds, name, center, min_periods, key, compute_backend
564588
) -> None:
565589
bn = pytest.importorskip("bottleneck", minversion="1.1")
566590

@@ -577,12 +601,12 @@ def test_rolling_wrapped_bottleneck(
577601
)
578602
else:
579603
raise ValueError
580-
assert_array_equal(actual[key].values, expected)
604+
np.testing.assert_allclose(actual[key].values, expected)
581605

582606
# Test center
583607
rolling_obj = ds.rolling(time=7, center=center)
584608
actual = getattr(rolling_obj, name)()["time"]
585-
assert_equal(actual, ds["time"])
609+
assert_allclose(actual, ds["time"])
586610

587611
@pytest.mark.parametrize("center", (True, False))
588612
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))

0 commit comments

Comments
 (0)