10
10
from xarray import DataArray , Dataset , set_options
11
11
from xarray .tests import (
12
12
assert_allclose ,
13
- assert_array_equal ,
14
13
assert_equal ,
15
14
assert_identical ,
16
15
has_dask ,
24
23
]
25
24
26
25
26
+ @pytest .fixture (params = ["numbagg" , "bottleneck" ])
27
+ def compute_backend (request ):
28
+ if request .param == "bottleneck" :
29
+ options = dict (use_bottleneck = True , use_numbagg = False )
30
+ elif request .param == "numbagg" :
31
+ options = dict (use_bottleneck = False , use_numbagg = True )
32
+ else :
33
+ raise ValueError
34
+
35
+ with xr .set_options (** options ):
36
+ yield request .param
37
+
38
+
27
39
class TestDataArrayRolling :
28
40
@pytest .mark .parametrize ("da" , (1 , 2 ), indirect = True )
29
41
@pytest .mark .parametrize ("center" , [True , False ])
@@ -87,9 +99,10 @@ def test_rolling_properties(self, da) -> None:
87
99
@pytest .mark .parametrize ("center" , (True , False , None ))
88
100
@pytest .mark .parametrize ("min_periods" , (1 , None ))
89
101
@pytest .mark .parametrize ("backend" , ["numpy" ], indirect = True )
90
- def test_rolling_wrapped_bottleneck (self , da , name , center , min_periods ) -> None :
102
+ def test_rolling_wrapped_bottleneck (
103
+ self , da , name , center , min_periods , compute_backend
104
+ ) -> None :
91
105
bn = pytest .importorskip ("bottleneck" , minversion = "1.1" )
92
-
93
106
# Test all bottleneck functions
94
107
rolling_obj = da .rolling (time = 7 , min_periods = min_periods )
95
108
@@ -98,15 +111,18 @@ def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None
98
111
expected = getattr (bn , func_name )(
99
112
da .values , window = 7 , axis = 1 , min_count = min_periods
100
113
)
101
- assert_array_equal (actual .values , expected )
114
+
115
+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
116
+ np .testing .assert_allclose (actual .values , expected )
102
117
103
118
with pytest .warns (DeprecationWarning , match = "Reductions are applied" ):
104
119
getattr (rolling_obj , name )(dim = "time" )
105
120
106
121
# Test center
107
122
rolling_obj = da .rolling (time = 7 , center = center )
108
123
actual = getattr (rolling_obj , name )()["time" ]
109
- assert_equal (actual , da ["time" ])
124
+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
125
+ assert_allclose (actual , da ["time" ])
110
126
111
127
@requires_dask
112
128
@pytest .mark .parametrize ("name" , ("mean" , "count" ))
@@ -153,7 +169,9 @@ def test_rolling_wrapped_dask_nochunk(self, center) -> None:
153
169
@pytest .mark .parametrize ("center" , (True , False ))
154
170
@pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
155
171
@pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
156
- def test_rolling_pandas_compat (self , center , window , min_periods ) -> None :
172
+ def test_rolling_pandas_compat (
173
+ self , center , window , min_periods , compute_backend
174
+ ) -> None :
157
175
s = pd .Series (np .arange (10 ))
158
176
da = DataArray .from_series (s )
159
177
@@ -203,7 +221,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
203
221
@pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
204
222
@pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
205
223
@pytest .mark .parametrize ("name" , ("sum" , "mean" , "std" , "max" ))
206
- def test_rolling_reduce (self , da , center , min_periods , window , name ) -> None :
224
+ def test_rolling_reduce (
225
+ self , da , center , min_periods , window , name , compute_backend
226
+ ) -> None :
207
227
if min_periods is not None and window < min_periods :
208
228
min_periods = window
209
229
@@ -223,7 +243,9 @@ def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
223
243
@pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
224
244
@pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
225
245
@pytest .mark .parametrize ("name" , ("sum" , "max" ))
226
- def test_rolling_reduce_nonnumeric (self , center , min_periods , window , name ) -> None :
246
+ def test_rolling_reduce_nonnumeric (
247
+ self , center , min_periods , window , name , compute_backend
248
+ ) -> None :
227
249
da = DataArray (
228
250
[0 , np .nan , 1 , 2 , np .nan , 3 , 4 , 5 , np .nan , 6 , 7 ], dims = "time"
229
251
).isnull ()
@@ -239,7 +261,7 @@ def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> N
239
261
assert_allclose (actual , expected )
240
262
assert actual .dims == expected .dims
241
263
242
- def test_rolling_count_correct (self ) -> None :
264
+ def test_rolling_count_correct (self , compute_backend ) -> None :
243
265
da = DataArray ([0 , np .nan , 1 , 2 , np .nan , 3 , 4 , 5 , np .nan , 6 , 7 ], dims = "time" )
244
266
245
267
kwargs : list [dict [str , Any ]] = [
@@ -279,7 +301,9 @@ def test_rolling_count_correct(self) -> None:
279
301
@pytest .mark .parametrize ("center" , (True , False ))
280
302
@pytest .mark .parametrize ("min_periods" , (None , 1 ))
281
303
@pytest .mark .parametrize ("name" , ("sum" , "mean" , "max" ))
282
- def test_ndrolling_reduce (self , da , center , min_periods , name ) -> None :
304
+ def test_ndrolling_reduce (
305
+ self , da , center , min_periods , name , compute_backend
306
+ ) -> None :
283
307
rolling_obj = da .rolling (time = 3 , x = 2 , center = center , min_periods = min_periods )
284
308
285
309
actual = getattr (rolling_obj , name )()
@@ -560,7 +584,7 @@ def test_rolling_properties(self, ds) -> None:
560
584
@pytest .mark .parametrize ("key" , ("z1" , "z2" ))
561
585
@pytest .mark .parametrize ("backend" , ["numpy" ], indirect = True )
562
586
def test_rolling_wrapped_bottleneck (
563
- self , ds , name , center , min_periods , key
587
+ self , ds , name , center , min_periods , key , compute_backend
564
588
) -> None :
565
589
bn = pytest .importorskip ("bottleneck" , minversion = "1.1" )
566
590
@@ -577,12 +601,12 @@ def test_rolling_wrapped_bottleneck(
577
601
)
578
602
else :
579
603
raise ValueError
580
- assert_array_equal (actual [key ].values , expected )
604
+ np . testing . assert_allclose (actual [key ].values , expected )
581
605
582
606
# Test center
583
607
rolling_obj = ds .rolling (time = 7 , center = center )
584
608
actual = getattr (rolling_obj , name )()["time" ]
585
- assert_equal (actual , ds ["time" ])
609
+ assert_allclose (actual , ds ["time" ])
586
610
587
611
@pytest .mark .parametrize ("center" , (True , False ))
588
612
@pytest .mark .parametrize ("min_periods" , (None , 1 , 2 , 3 ))
0 commit comments