Skip to content

Commit ab800d8

Browse files
authored
6066 pad mode (#6076)
Fixes #6066 ### Description prefer the pytorch backend as much as possible ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent c2fc083 commit ab800d8

2 files changed

Lines changed: 21 additions & 8 deletions

File tree

monai/transforms/croppad/functional.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from __future__ import annotations
1717

18+
import warnings
19+
1820
import numpy as np
1921
import torch
2022
from torch.nn.functional import pad as pad_pt
@@ -29,7 +31,12 @@
2931

3032

3133
def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
32-
img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
34+
if isinstance(img, torch.Tensor):
35+
if img.is_cuda:
36+
warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
37+
img_np = img.detach().cpu().numpy()
38+
else:
39+
img_np = img
3340
mode = convert_pad_mode(dst=img_np, mode=mode).value
3441
if mode == "constant" and "value" in kwargs:
3542
kwargs["constant_values"] = kwargs.pop("value")
@@ -40,9 +47,15 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw
4047

4148

4249
def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
50+
mode = convert_pad_mode(dst=img, mode=mode).value
51+
if mode == "constant" and "constant_values" in kwargs:
52+
_kwargs = kwargs.copy()
53+
_kwargs["value"] = _kwargs.pop("constant_values")
54+
else:
55+
_kwargs = kwargs
4356
pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
4457
# torch.pad expects `[B, C, H, W, [D]]` shape
45-
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0)
58+
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)
4659

4760

4861
def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs):
@@ -68,14 +81,14 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs
6881
mode = convert_pad_mode(dst=img, mode=mode).value
6982
try:
7083
_pad = (
71-
_pt_pad
72-
if mode in {"reflect", "replicate"} and img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
73-
else _np_pad
84+
_np_pad
85+
if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8}
86+
else _pt_pad
7487
)
7588
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
7689
except (ValueError, TypeError, RuntimeError) as err:
7790
if isinstance(err, NotImplementedError) or any(
78-
k in str(err) for k in ("supported", "unexpected keyword", "implemented")
91+
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
7992
):
8093
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
8194
raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err

monai/transforms/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,13 +1628,13 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: str | None):
16281628
if isinstance(dst, torch.Tensor):
16291629
if mode == "wrap":
16301630
mode = "circular"
1631-
if mode == "edge":
1631+
elif mode == "edge":
16321632
mode = "replicate"
16331633
return look_up_option(mode, PytorchPadMode)
16341634
if isinstance(dst, np.ndarray):
16351635
if mode == "circular":
16361636
mode = "wrap"
1637-
if mode == "replicate":
1637+
elif mode == "replicate":
16381638
mode = "edge"
16391639
return look_up_option(mode, NumpyPadMode)
16401640
raise ValueError(f"unsupported data type: {type(dst)}.")

0 commit comments

Comments
 (0)