Skip to content

Commit c6c3a35

Browse files
committed
[DLMED] update according to comments
Signed-off-by: Nic Ma <[email protected]>
1 parent 649a7c5 commit c6c3a35

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

monai/transforms/compose.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
import warnings
16+
from copy import deepcopy
1617
from typing import Any, Callable, Mapping, Optional, Sequence, Union
1718

1819
import numpy as np
@@ -245,16 +246,17 @@ def inverse(self, data):
245246
if not isinstance(data, Mapping):
246247
raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")
247248

249+
d = deepcopy(dict(data))
248250
# loop until we get an index and then break (since they'll all be the same)
249251
key = self.__class__.__name__
250-
if self.trace_key(key) not in data:
252+
if self.trace_key(key) not in d:
251253
raise RuntimeError("can not find the index of transform have been applied.")
252254

253255
# get the index of the applied OneOf transform
254-
index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
256+
index = self.get_most_recent_transform(d, key)[TraceKeys.EXTRA_INFO]["index"]
255257
# and then remove the OneOf transform
256-
self.pop_transform(data, key)
258+
self.pop_transform(d, key)
257259

258260
_transform = self.transforms[index]
259261
# apply the inverse
260-
return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data
262+
return _transform.inverse(d) if isinstance(_transform, InvertibleTransform) else d

tests/test_one_of.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,31 +150,29 @@ def _match(a, b):
150150
@parameterized.expand(TEST_INVERSES)
151151
def test_inverse(self, transform, invertible):
152152
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
153+
key = OneOf.__name__
153154
fwd_data = transform(data)
154-
155-
if invertible:
156-
for k in KEYS:
157-
t = fwd_data[TraceableTransform.trace_key(k)][-1]
158-
# make sure the OneOf index was stored
159-
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
160-
# make sure index exists and is in bounds
161-
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))
155+
t = fwd_data[TraceableTransform.trace_key(key)][-1]
156+
# make sure the OneOf index was stored
157+
self.assertEqual(t[TraceKeys.CLASS_NAME], key)
158+
# make sure index exists and is in bounds
159+
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))
162160

163161
# call the inverse
164162
fwd_inv_data = transform.inverse(fwd_data)
165163

166-
if invertible:
167-
for k in KEYS:
168-
# check transform was removed
169-
self.assertTrue(
170-
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
171-
)
172-
# check data is same as original (and different from forward)
173-
self.assertEqual(fwd_inv_data[k], data[k])
164+
# check transform was removed
165+
self.assertTrue(
166+
len(fwd_inv_data[TraceableTransform.trace_key(key)]) < len(fwd_data[TraceableTransform.trace_key(key)])
167+
)
168+
# check data is same as original (and different from forward)
169+
for k, v in data.items():
170+
if invertible:
171+
self.assertEqual(fwd_inv_data[k], v)
174172
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
175-
else:
176-
# if not invertible, should not change the data
177-
self.assertDictEqual(fwd_data, fwd_inv_data)
173+
else:
174+
# if not invertible, should not change the data
175+
self.assertEqual(fwd_inv_data[k], fwd_data[k])
178176

179177
def test_inverse_compose(self):
180178
transform = Compose(

0 commit comments

Comments
 (0)