Skip to content

Commit 5807f5d

Browse files
committed
3628 0-d array to_contiguous (#3629)
* 0-d array Signed-off-by: Wenqi Li <[email protected]> * handling string types Signed-off-by: Wenqi Li <[email protected]>
1 parent 5353d39 commit 5807f5d

File tree

3 files changed

+53
-5
lines changed

3 files changed

+53
-5
lines changed

monai/transforms/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import warnings
1515
from contextlib import contextmanager
1616
from inspect import getmembers, isclass
17-
from typing import Any, Callable, Hashable, Iterable, List, Optional, Sequence, Tuple, Union
17+
from typing import Any, Callable, Hashable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
1818

1919
import numpy as np
2020
import torch
@@ -1508,11 +1508,11 @@ def convert_to_contiguous(data, **kwargs):
15081508
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.
15091509
15101510
"""
1511-
if isinstance(data, (np.ndarray, torch.Tensor)):
1511+
if isinstance(data, (np.ndarray, torch.Tensor, str, bytes)):
15121512
return ascontiguousarray(data, **kwargs)
1513-
if isinstance(data, dict):
1513+
if isinstance(data, Mapping):
15141514
return {k: convert_to_contiguous(v, **kwargs) for k, v in data.items()}
1515-
if isinstance(data, (list, tuple)):
1515+
if isinstance(data, Sequence):
15161516
return [convert_to_contiguous(i, **kwargs) for i in data]
15171517
return data
15181518

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,9 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
334334
335335
"""
336336
if isinstance(x, np.ndarray):
337+
if x.ndim == 0:
338+
return x
337339
return np.ascontiguousarray(x)
338-
return x.contiguous(**kwargs)
340+
if isinstance(x, torch.Tensor):
341+
return x.contiguous(**kwargs)
342+
return x

tests/test_to_contiguous.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
17+
from monai.transforms import convert_to_contiguous
18+
from tests.utils import assert_allclose
19+
20+
21+
class TestToContiguous(unittest.TestCase):
22+
def test_decollation_dict(self):
23+
tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1)
24+
test_dict = {"test_key": [[1]], 0: np.array(0), 1: np.array([0]), "nested": {"nested": [tochange]}}
25+
output = convert_to_contiguous(test_dict)
26+
self.assertEqual(output["test_key"], [[1]])
27+
assert_allclose(output[0], np.array(0))
28+
assert_allclose(output[1], np.array([0]))
29+
self.assertTrue(output["nested"]["nested"][0].flags.c_contiguous)
30+
31+
def test_decollation_seq(self):
32+
tochange = torch.zeros(2, 3, 4).transpose(0, 1)
33+
test_dict = [[[1]], np.array(0), np.array([0]), torch.tensor(1.0), [[tochange]], "test_string"]
34+
output = convert_to_contiguous(test_dict)
35+
self.assertEqual(output[0], [[1]])
36+
assert_allclose(output[1], np.array(0))
37+
assert_allclose(output[2], np.array([0]))
38+
assert_allclose(output[3], torch.tensor(1.0))
39+
self.assertTrue(output[4][0][0].is_contiguous())
40+
self.assertEqual(output[5], "test_string")
41+
42+
43+
if __name__ == "__main__":
44+
unittest.main()

0 commit comments

Comments
 (0)