Skip to content

Commit c968907

Browse files
mccleericspodKumoLiu
authored
Fix #8599: Add track_meta and weights_only arguments to PersistentDataset for MetaTensor support. (#8628)
Fixes #8599.

### Description

`PersistentDataset` currently casts all `MetaTensor` objects to `torch.Tensor` objects and forces the use of `torch.load` with `weights_only=True`. This makes it impossible to save metadata to, or load it from, cached files, even though that metadata may be necessary for accurate post-transform operations.

To address this, this PR adds the `track_meta` and `weights_only` arguments directly to `PersistentDataset`. They are passed internally to `convert_to_tensor` and `torch.load`, respectively. A `ValueError` is raised when `track_meta=True` is combined with `weights_only=True`, since `MetaTensor` objects cannot be loaded with `weights_only=True` and the cached files would be continually deleted and rewritten.

These changes restore the ability to cache `MetaTensor` objects by allowing explicit control over data casting and `torch.load` behavior. The default values of `track_meta=False` and `weights_only=True` will preserve the current behavior of `PersistentDataset`.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Mason Cleveland <[email protected]>
Signed-off-by: Mason C. Cleveland <[email protected]>
Signed-off-by: mccle <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent 806c0e8 commit c968907

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

monai/data/dataset.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def __init__(
230230
pickle_protocol: int = DEFAULT_PROTOCOL,
231231
hash_transform: Callable[..., bytes] | None = None,
232232
reset_ops_id: bool = True,
233+
track_meta: bool = False,
234+
weights_only: bool = True,
233235
) -> None:
234236
"""
235237
Args:
@@ -264,7 +266,17 @@ def __init__(
264266
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
265267
This is useful for skipping the transform instance checks when inverting applied operations
266268
using the cached content and with re-created transform instances.
267-
269+
track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
270+
default to `False`. Cannot be used with `weights_only=True`.
271+
weights_only: keyword argument passed to `torch.load` when reading cached files.
272+
default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
273+
other safe objects. Setting this to `False` is required for loading `MetaTensor`
274+
objects saved with `track_meta=True`, however this creates the possibility of remote
275+
code execution through `torch.load` so be aware of the security implications of doing so.
276+
277+
Raises:
278+
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
279+
prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
268280
"""
269281
super().__init__(data=data, transform=transform)
270282
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
280292
if hash_transform is not None:
281293
self.set_transform_hash(hash_transform)
282294
self.reset_ops_id = reset_ops_id
295+
if track_meta and weights_only:
296+
raise ValueError(
297+
"Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
298+
"To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
299+
)
300+
self.track_meta = track_meta
301+
self.weights_only = weights_only
283302

284303
def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
285304
"""Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):
377396

378397
if hashfile is not None and hashfile.is_file(): # cache hit
379398
try:
380-
return torch.load(hashfile, weights_only=True)
399+
return torch.load(hashfile, weights_only=self.weights_only)
381400
except PermissionError as e:
382401
if sys.platform != "win32":
383402
raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
398417
with tempfile.TemporaryDirectory() as tmpdirname:
399418
temp_hash_file = Path(tmpdirname) / hashfile.name
400419
torch.save(
401-
obj=convert_to_tensor(_item_transformed, convert_numeric=False),
420+
obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
402421
f=temp_hash_file,
403422
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
404423
pickle_protocol=self.pickle_protocol,

tests/data/test_persistentdataset.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import contextlib
1415
import os
1516
import tempfile
1617
import unittest
@@ -20,7 +21,7 @@
2021
import torch
2122
from parameterized import parameterized
2223

23-
from monai.data import PersistentDataset, json_hashing
24+
from monai.data import MetaTensor, PersistentDataset, json_hashing
2425
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
2526

2627
TEST_CASE_1 = [
@@ -43,9 +44,16 @@
4344

4445
TEST_CASE_3 = [None, (128, 128, 128)]
4546

47+
TEST_CASE_4 = [True, False, False, MetaTensor]
48+
49+
TEST_CASE_5 = [True, True, True, None]
50+
51+
TEST_CASE_6 = [False, False, False, torch.Tensor]
52+
53+
TEST_CASE_7 = [False, True, False, torch.Tensor]
4654

47-
class _InplaceXform(Transform):
4855

56+
class _InplaceXform(Transform):
4957
def __call__(self, data):
5058
if data:
5159
data[0] = data[0] + np.pi
@@ -55,7 +63,6 @@ def __call__(self, data):
5563

5664

5765
class TestDataset(unittest.TestCase):
58-
5966
def test_cache(self):
6067
"""testing no inplace change to the hashed item"""
6168
items = [[list(range(i))] for i in range(5)]
@@ -168,6 +175,31 @@ def test_different_transforms(self):
168175
l2 = ((im1 - im2) ** 2).sum() ** 0.5
169176
self.assertGreater(l2, 1)
170177

178+
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
179+
def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
180+
"""
181+
Ensure expected behavior for all combinations of `track_meta` and `weights_only`.
182+
"""
183+
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
184+
with tempfile.TemporaryDirectory() as tempdir:
185+
nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz"))
186+
test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}]
187+
transform = Compose([LoadImaged(keys=["image"])])
188+
cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
189+
190+
cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext()
191+
with cm:
192+
test_dataset = PersistentDataset(
193+
data=test_data,
194+
transform=transform,
195+
cache_dir=cache_dir,
196+
track_meta=track_meta,
197+
weights_only=weights_only,
198+
)
199+
200+
im = test_dataset[0]["image"]
201+
self.assertIsInstance(im, expected_type)
202+
171203

172204
if __name__ == "__main__":
173205
unittest.main()

0 commit comments

Comments
 (0)