@@ -150,31 +150,29 @@ def _match(a, b):
150
150
@parameterized .expand (TEST_INVERSES )
151
151
def test_inverse (self , transform , invertible ):
152
152
data = {k : (i + 1 ) * 10.0 for i , k in enumerate (KEYS )}
153
+ key = OneOf .__name__
153
154
fwd_data = transform (data )
154
-
155
- if invertible :
156
- for k in KEYS :
157
- t = fwd_data [TraceableTransform .trace_key (k )][- 1 ]
158
- # make sure the OneOf index was stored
159
- self .assertEqual (t [TraceKeys .CLASS_NAME ], OneOf .__name__ )
160
- # make sure index exists and is in bounds
161
- self .assertTrue (0 <= t [TraceKeys .EXTRA_INFO ]["index" ] < len (transform ))
155
+ t = fwd_data [TraceableTransform .trace_key (key )][- 1 ]
156
+ # make sure the OneOf index was stored
157
+ self .assertEqual (t [TraceKeys .CLASS_NAME ], key )
158
+ # make sure index exists and is in bounds
159
+ self .assertTrue (0 <= t [TraceKeys .EXTRA_INFO ]["index" ] < len (transform ))
162
160
163
161
# call the inverse
164
162
fwd_inv_data = transform .inverse (fwd_data )
165
163
166
- if invertible :
167
- for k in KEYS :
168
- # check transform was removed
169
- self . assertTrue (
170
- len ( fwd_inv_data [ TraceableTransform . trace_key ( k )]) < len ( fwd_data [ TraceableTransform . trace_key ( k )] )
171
- )
172
- # check data is same as original (and different from forward)
173
- self .assertEqual (fwd_inv_data [k ], data [ k ] )
164
+ # check transform was removed
165
+ self . assertTrue (
166
+ len ( fwd_inv_data [ TraceableTransform . trace_key ( key )]) < len ( fwd_data [ TraceableTransform . trace_key ( key )])
167
+ )
168
+ # check data is same as original (and different from forward )
169
+ for k , v in data . items ():
170
+ if invertible :
171
+ self .assertEqual (fwd_inv_data [k ], v )
174
172
self .assertNotEqual (fwd_inv_data [k ], fwd_data [k ])
175
- else :
176
- # if not invertible, should not change the data
177
- self .assertDictEqual ( fwd_data , fwd_inv_data )
173
+ else :
174
+ # if not invertible, should not change the data
175
+ self .assertEqual ( fwd_inv_data [ k ], fwd_data [ k ] )
178
176
179
177
def test_inverse_compose (self ):
180
178
transform = Compose (
0 commit comments