Skip to content

Commit c0911e3

Browse files
authored
Update typehint for fill arg in rotate (#6594)
1 parent 753bf18 commit c0911e3

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

test/test_prototype_transforms_functional.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,20 @@ def affine_mask():
102102

103103
@register_kernel_info_from_sample_inputs_fn
104104
def rotate_image_tensor():
105-
for image, angle, expand, center, fill in itertools.product(
105+
for image, angle, expand, center in itertools.product(
106106
make_images(),
107107
[-87, 15, 90], # angle
108108
[True, False], # expand
109109
[None, [12, 23]], # center
110-
[None, [128], [12.0]], # fill
111110
):
112111
if center is not None and expand:
113112
# Skip warning: The provided center argument is ignored if expand is True
114113
continue
115114

116-
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill)
115+
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=None)
116+
117+
for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]:
118+
yield ArgsKwargs(image, angle=23, expand=False, center=None, fill=fill)
117119

118120

119121
@register_kernel_info_from_sample_inputs_fn

torchvision/prototype/transforms/functional/_geometry.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def rotate_image_tensor(
467467
angle: float,
468468
interpolation: InterpolationMode = InterpolationMode.NEAREST,
469469
expand: bool = False,
470-
fill: Optional[List[float]] = None,
470+
fill: Optional[Union[int, float, List[float]]] = None,
471471
center: Optional[List[float]] = None,
472472
) -> torch.Tensor:
473473
num_channels, height, width = img.shape[-3:]

torchvision/transforms/functional_tensor.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs(
475475
img: Tensor,
476476
matrix: Optional[List[float]],
477477
interpolation: str,
478-
fill: Optional[List[float]],
478+
fill: Optional[Union[int, float, List[float]]],
479479
supported_interpolation_modes: List[str],
480480
coeffs: Optional[List[float]] = None,
481481
) -> None:
@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
499499

500500
# Check fill
501501
num_channels = get_dimensions(img)[0]
502-
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
502+
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
503503
msg = (
504504
"The number of elements in 'fill' cannot broadcast to match the number of "
505505
"channels of the image ({} != {})"
@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
539539
return img
540540

541541

542-
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
542+
def _apply_grid_transform(
543+
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
544+
) -> Tensor:
543545

544546
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
545547

@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
559561
mask = img[:, -1:, :, :] # N * 1 * H * W
560562
img = img[:, :-1, :, :] # N * C * H * W
561563
mask = mask.expand_as(img)
562-
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
563-
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
564+
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
565+
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
564566
if mode == "nearest":
565567
mask = mask < 0.5
566568
img[mask] = fill_img[mask]
@@ -648,7 +650,7 @@ def rotate(
648650
matrix: List[float],
649651
interpolation: str = "nearest",
650652
expand: bool = False,
651-
fill: Optional[List[float]] = None,
653+
fill: Optional[Union[int, float, List[float]]] = None,
652654
) -> Tensor:
653655
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
654656
w, h = img.shape[-1], img.shape[-2]

0 commit comments

Comments
 (0)