1
1
import unittest
2
2
3
+ import numpy as np
3
4
from numpy .testing import assert_array_equal
4
5
5
6
from xray import utils , DataArray
@@ -36,12 +37,26 @@ def requires_netCDF4(test):
36
37
return test if has_netCDF4 else unittest .skip ('requires netCDF4' )(test )
37
38
38
39
40
+ def data_allclose_or_equiv (arr1 , arr2 , rtol = 1e-05 , atol = 1e-08 ):
41
+ exact_dtypes = [np .datetime64 , np .timedelta64 , np .string_ ]
42
+ if any (any (np .issubdtype (arr .dtype , t ) for t in exact_dtypes )
43
+ or arr .dtype == object for arr in [arr1 , arr2 ]):
44
+ return np .array_equal (arr1 , arr2 )
45
+ else :
46
+ return utils .allclose_or_equiv (arr1 , arr2 , rtol = rtol , atol = atol )
47
+
48
+
39
49
class TestCase (unittest .TestCase ):
40
50
def assertVariableEqual (self , v1 , v2 ):
41
51
self .assertTrue (as_variable (v1 ).equals (v2 ))
42
52
53
+ def assertVariableIdentical (self , v1 , v2 ):
54
+ self .assertTrue (as_variable (v1 ).identical (v2 ))
55
+
43
56
def assertVariableAllClose (self , v1 , v2 , rtol = 1e-05 , atol = 1e-08 ):
44
- self .assertTrue (utils .variable_allclose (v1 , v2 , rtol = rtol , atol = atol ))
57
+ self .assertEqual (v1 .dimensions , v2 .dimensions )
58
+ self .assertTrue (data_allclose_or_equiv (v1 .values , v2 .values ,
59
+ rtol = rtol , atol = atol ))
45
60
46
61
def assertVariableNotEqual (self , v1 , v2 ):
47
62
self .assertFalse (as_variable (v1 ).equals (v2 ))
@@ -52,36 +67,47 @@ def assertArrayEqual(self, a1, a2):
52
67
def assertDatasetEqual (self , d1 , d2 ):
53
68
# this method is functionally equivalent to `assert d1 == d2`, but it
54
69
# checks each aspect of equality separately for easier debugging
55
- self .assertTrue (utils .dict_equal (d1 .attributes , d2 .attributes ))
56
70
self .assertEqual (sorted (d1 .variables ), sorted (d2 .variables ))
57
71
for k in d1 :
58
72
v1 = d1 .variables [k ]
59
73
v2 = d2 .variables [k ]
60
74
self .assertVariableEqual (v1 , v2 )
61
75
76
+ def assertDatasetIdentical (self , d1 , d2 ):
77
+ # this method is functionally equivalent to `assert d1.identical(d2)`,
78
+ # but it checks each aspect of equality separately for easier debugging
79
+ self .assertTrue (utils .dict_equal (d1 .attrs , d2 .attrs ))
80
+ self .assertEqual (sorted (d1 .variables ), sorted (d2 .variables ))
81
+ for k in d1 :
82
+ v1 = d1 .variables [k ]
83
+ v2 = d2 .variables [k ]
84
+ self .assertTrue (v1 .identical (v2 ))
85
+
62
86
def assertDatasetAllClose (self , d1 , d2 , rtol = 1e-05 , atol = 1e-08 ):
63
- self .assertTrue (utils .dict_equal (d1 .attributes , d2 .attributes ))
64
87
self .assertEqual (sorted (d1 .variables ), sorted (d2 .variables ))
65
88
for k in d1 :
66
89
v1 = d1 .variables [k ]
67
90
v2 = d2 .variables [k ]
68
91
self .assertVariableAllClose (v1 , v2 , rtol = rtol , atol = atol )
69
92
93
+ def assertCoordsEqual (self , d1 , d2 ):
94
+ self .assertEqual (sorted (d1 .coordinates ), sorted (d2 .coordinates ))
95
+ for k in d1 .coordinates :
96
+ v1 = d1 .coordinates [k ]
97
+ v2 = d2 .coordinates [k ]
98
+ self .assertVariableEqual (v1 , v2 )
99
+
70
100
def assertDataArrayEqual (self , ar1 , ar2 ):
101
+ self .assertVariableEqual (ar1 , ar2 )
102
+ self .assertCoordsEqual (ar1 , ar2 )
103
+
104
+ def assertDataArrayIdentical (self , ar1 , ar2 ):
71
105
self .assertEqual (ar1 .name , ar2 .name )
72
- self .assertDatasetEqual (ar1 .dataset , ar2 .dataset )
106
+ self .assertDatasetIdentical (ar1 .dataset , ar2 .dataset )
73
107
74
108
def assertDataArrayAllClose (self , ar1 , ar2 , rtol = 1e-05 , atol = 1e-08 ):
75
- self .assertEqual (ar1 .name , ar2 .name )
76
- self .assertDatasetAllClose (ar1 .dataset , ar2 .dataset ,
77
- rtol = rtol , atol = atol )
78
-
79
- def assertDataArrayEquiv (self , ar1 , ar2 ):
80
- self .assertIsInstance (ar1 , DataArray )
81
- self .assertIsInstance (ar2 , DataArray )
82
- random_name = 'randomly-renamed-variable'
83
- self .assertDataArrayEqual (ar1 .rename (random_name ),
84
- ar2 .rename (random_name ))
109
+ self .assertVariableAllClose (ar1 , ar2 , rtol = rtol , atol = atol )
110
+ self .assertCoordsEqual (ar1 , ar2 )
85
111
86
112
87
113
class ReturnItem (object ):
0 commit comments