@@ -101,25 +101,32 @@ def assert_equal(a, b, tolerance=None):
101
101
else :
102
102
tolerance = {}
103
103
104
- if has_dask and isinstance (a , dask_array_type ) or isinstance (b , dask_array_type ):
105
- # sometimes it's nice to see values and shapes
106
- # rather than being dropped into some file in dask
107
- np .testing .assert_allclose (a , b , ** tolerance )
108
- # does some validation of the dask graph
109
- da .utils .assert_eq (a , b , equal_nan = True )
104
+ # Always run the numpy comparison first, so that we get nice error messages with dask.
105
+ # sometimes it's nice to see values and shapes
106
+ # rather than being dropped into some file in dask
107
+ if a .dtype != b .dtype :
108
+ raise AssertionError (f"a and b have different dtypes: (a: { a .dtype } , b: { b .dtype } )" )
109
+
110
+ if has_dask :
111
+ a_eager = a .compute () if isinstance (a , dask_array_type ) else a
112
+ b_eager = b .compute () if isinstance (b , dask_array_type ) else b
113
+
114
+ if a .dtype .kind in "SUMm" :
115
+ np .testing .assert_equal (a_eager , b_eager )
110
116
else :
111
- if a .dtype != b .dtype :
112
- raise AssertionError (f"a and b have different dtypes: (a: { a .dtype } , b: { b .dtype } )" )
117
+ np .testing .assert_allclose (a_eager , b_eager , equal_nan = True , ** tolerance )
113
118
114
- np .testing .assert_allclose (a , b , equal_nan = True , ** tolerance )
119
+ if has_dask and isinstance (a , dask_array_type ) or isinstance (b , dask_array_type ):
120
+ # does some validation of the dask graph
121
+ dask_assert_eq (a , b , equal_nan = True )
115
122
116
123
117
124
def assert_equal_tuple (a , b ):
118
125
"""assert_equal for .blocks indexing tuples"""
119
126
assert len (a ) == len (b )
120
127
121
128
for a_ , b_ in zip (a , b ):
122
- assert type (a_ ) == type (b_ )
129
+ assert type (a_ ) is type (b_ )
123
130
if isinstance (a_ , np .ndarray ):
124
131
np .testing .assert_array_equal (a_ , b_ )
125
132
else :
@@ -156,3 +163,91 @@ def assert_equal_tuple(a, b):
156
163
"quantile" ,
157
164
"nanquantile" ,
158
165
) + tuple (SCIPY_STATS_FUNCS )
166
+
167
+
168
+ def dask_assert_eq (
169
+ a ,
170
+ b ,
171
+ check_shape = True ,
172
+ check_graph = True ,
173
+ check_meta = True ,
174
+ check_chunks = True ,
175
+ check_ndim = True ,
176
+ check_type = True ,
177
+ check_dtype = True ,
178
+ equal_nan = True ,
179
+ scheduler = "sync" ,
180
+ ** kwargs ,
181
+ ):
182
+ """dask.array.utils.assert_eq modified to skip value checks. Their code is buggy for some dtypes.
183
+ We just check values through numpy and care about validating the graph in this function."""
184
+ from dask .array .utils import _get_dt_meta_computed
185
+
186
+ a_original = a
187
+ b_original = b
188
+
189
+ if isinstance (a , (list , int , float )):
190
+ a = np .array (a )
191
+ if isinstance (b , (list , int , float )):
192
+ b = np .array (b )
193
+
194
+ a , adt , a_meta , a_computed = _get_dt_meta_computed (
195
+ a ,
196
+ check_shape = check_shape ,
197
+ check_graph = check_graph ,
198
+ check_chunks = check_chunks ,
199
+ check_ndim = check_ndim ,
200
+ scheduler = scheduler ,
201
+ )
202
+ b , bdt , b_meta , b_computed = _get_dt_meta_computed (
203
+ b ,
204
+ check_shape = check_shape ,
205
+ check_graph = check_graph ,
206
+ check_chunks = check_chunks ,
207
+ check_ndim = check_ndim ,
208
+ scheduler = scheduler ,
209
+ )
210
+
211
+ if check_type :
212
+ _a = a if a .shape else a .item ()
213
+ _b = b if b .shape else b .item ()
214
+ assert type (_a ) is type (_b ), f"a and b have different types (a: { type (_a )} , b: { type (_b )} )"
215
+ if check_meta :
216
+ if hasattr (a , "_meta" ) and hasattr (b , "_meta" ):
217
+ dask_assert_eq (a ._meta , b ._meta )
218
+ if hasattr (a_original , "_meta" ):
219
+ msg = (
220
+ f"compute()-ing 'a' changes its number of dimensions "
221
+ f"(before: { a_original ._meta .ndim } , after: { a .ndim } )"
222
+ )
223
+ assert a_original ._meta .ndim == a .ndim , msg
224
+ if a_meta is not None :
225
+ msg = (
226
+ f"compute()-ing 'a' changes its type "
227
+ f"(before: { type (a_original ._meta )} , after: { type (a_meta )} )"
228
+ )
229
+ assert type (a_original ._meta ) is type (a_meta ), msg
230
+ if not (np .isscalar (a_meta ) or np .isscalar (a_computed )):
231
+ msg = (
232
+ f"compute()-ing 'a' results in a different type than implied by its metadata "
233
+ f"(meta: { type (a_meta )} , computed: { type (a_computed )} )"
234
+ )
235
+ assert type (a_meta ) is type (a_computed ), msg
236
+ if hasattr (b_original , "_meta" ):
237
+ msg = (
238
+ f"compute()-ing 'b' changes its number of dimensions "
239
+ f"(before: { b_original ._meta .ndim } , after: { b .ndim } )"
240
+ )
241
+ assert b_original ._meta .ndim == b .ndim , msg
242
+ if b_meta is not None :
243
+ msg = (
244
+ f"compute()-ing 'b' changes its type "
245
+ f"(before: { type (b_original ._meta )} , after: { type (b_meta )} )"
246
+ )
247
+ assert type (b_original ._meta ) is type (b_meta ), msg
248
+ if not (np .isscalar (b_meta ) or np .isscalar (b_computed )):
249
+ msg = (
250
+ f"compute()-ing 'b' results in a different type than implied by its metadata "
251
+ f"(meta: { type (b_meta )} , computed: { type (b_computed )} )"
252
+ )
253
+ assert type (b_meta ) is type (b_computed ), msg
0 commit comments