forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_splitdimd.py
More file actions
96 lines (82 loc) · 3.73 KB
/
test_splitdimd.py
File metadata and controls
96 lines (82 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from copy import deepcopy
import numpy as np
import torch
from parameterized import parameterized
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImaged
from monai.transforms.utility.dictionary import SplitDimd
from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine
# Parameter grid for ``TestSplitDimd.test_correct``: every array type crossed with each
# combination of the three boolean options. Each case is (keepdim, im_type, update_meta, list_output).
TESTS = [
    (keepdim, p, update_meta, list_output)
    for p in TEST_NDARRAYS
    for keepdim in (True, False)
    for update_meta in (True, False)
    for list_output in (True, False)
]
class TestSplitDimd(unittest.TestCase):
    """Tests for the dictionary-based ``SplitDimd`` transform."""

    # Result of ``LoadImaged``: a dict whose "i" entry holds the loaded image.
    data: dict

    @classmethod
    def setUpClass(cls) -> None:
        """Build a random 4D NIfTI image with a random affine and load it once for all tests."""
        arr = np.random.rand(2, 10, 8, 7)
        affine = make_rand_affine()
        data = {"i": make_nifti_image(arr, affine)}
        loader = LoadImaged("i", image_only=True)
        cls.data = loader(data)

    def _assert_same_world_coords(self, arr, split_im, idx, dim):
        """Assert that voxel ``idx`` of ``arr`` and the matching voxel of ``split_im`` map to one world point.

        ``split_im`` is the split piece that contains ``idx``, so its own index along ``dim`` is 0.
        Only runs when both images are ``MetaTensor`` instances (they carry an ``affine``);
        plain arrays/tensors are skipped.
        """
        if not (isinstance(arr, MetaTensor) and isinstance(split_im, MetaTensor)):
            return
        split_idx = list(idx)
        split_idx[dim] = 0

        def to_world(affine, voxel):
            # voxel[1:] drops the channel index; the trailing 1 makes it a homogeneous coordinate.
            # Build the vector with the affine's own dtype/device so the matmul never mismatches.
            vec = torch.tensor(list(voxel[1:]) + [1], dtype=affine.dtype, device=affine.device)
            return affine @ vec

        assert_allclose(to_world(arr.affine, idx), to_world(split_im.affine, split_idx))

    @parameterized.expand(TESTS)
    def test_correct(self, keepdim, im_type, update_meta, list_output):
        """Split along every axis and check output container, count, rank, metadata and view semantics."""
        data = deepcopy(self.data)
        data["i"] = im_type(data["i"])
        arr = data["i"]
        for dim in range(arr.ndim):
            out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data)
            if list_output:
                self.assertIsInstance(out, list)
                self.assertEqual(len(out), arr.shape[dim])
            else:
                self.assertIsInstance(out, dict)
                # the input keys are retained and one new key is added per split
                self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim])
            # if updating metadata, pick some random points and
            # check same world coordinates between input and output.
            # NOTE: the check must inspect `arr` (the image), not `data` (a dict) -- otherwise it never runs.
            if update_meta:
                for _ in range(10):
                    idx = [np.random.choice(size) for size in arr.shape]
                    split_im_idx = idx[dim]
                    split_im = out[split_im_idx]["i"] if list_output else out[f"i_{split_im_idx}"]
                    self._assert_same_world_coords(arr, split_im, idx, dim)
            first_split = out[0]["i"] if list_output else out["i_0"]
            expected_ndim = arr.ndim if keepdim else arr.ndim - 1
            self.assertEqual(first_split.ndim, expected_ndim)
            # assert is a shallow copy: a write to the input must be visible in the first split
            arr[0, 0, 0, 0] *= 2
            self.assertEqual(arr.flatten()[0], first_split.flatten()[0])

    def test_singleton(self):
        """Splitting along a size-1 axis with default options yields one split equal in shape to the input."""
        shape = (2, 1, 8, 7)
        for p in TEST_NDARRAYS:
            arr = p(np.random.rand(*shape))
            out = SplitDimd("i", dim=1)({"i": arr})
            # "i" is the retained input; "i_0" is the single split produced along dim 1
            self.assertEqual(out["i"].shape, shape)
            self.assertEqual(out["i_0"].shape, shape)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()