Skip to content

Added RandomAffine to prototype Transforms API #6130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_segmentation_masks,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

Expand Down Expand Up @@ -49,11 +54,13 @@ def parametrize_from_transforms(*transforms):
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_segmentation_masks,
]:
inputs = list(creation_fn())
# vfdev: this looks scary
try:
output = transform(inputs[0])
except Exception:
except TypeError:
continue
else:
if output is inputs[0]:
Expand All @@ -68,10 +75,11 @@ class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
# transforms.CenterCrop([16, 16]), # This transform needs to be updated (bbox, segm mask support)
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
# transforms.Pad(5), # This transform is broken
transforms.RandomAffine(10, (0.2, 0.3), (0.7, 1.2), 0.1, fill=1.0),
)
def test_common(self, transform, input):
transform(input)
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def resize_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def affine_image_tensor():
for image, angle, translate, scale, shear in itertools.product(
make_images(extra_dims=((), (4,))),
make_images(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RandomVerticalFlip,
Pad,
RandomZoomOut,
RandomAffine,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
Expand Down
116 changes: 114 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
import math
import numbers
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import pil_to_tensor, InterpolationMode
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from torchvision.transforms.transforms import (
_setup_size,
_interpolation_modes_from_int,
_setup_angle,
_check_sequence_input,
)
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
Expand Down Expand Up @@ -125,6 +130,9 @@ def __init__(
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")

# TODO: Let's remove this compatibility for the prototype
# Otherwise, apply the same logic for Resize and other ops with interpolate arg.
#
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
Expand Down Expand Up @@ -388,3 +396,107 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
transform = Pad(**params, padding_mode="constant")
return transform(input)


class RandomAffine(Transform):
def __init__(
self,
degrees: Union[float, Sequence[float]],
translate: Optional[Tuple[float, float]] = None,
scale: Optional[Tuple[float, float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[float, List[float]] = 0.0,
center: Optional[List[int]] = None,
) -> None:
super().__init__()

self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate

if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale

if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear

self.interpolation = interpolation

if not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))

self.center = center

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, orig_h, orig_w = get_image_dimensions(image)

angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
if self.translate is not None:
max_dx = float(self.translate[0] * orig_w)
max_dy = float(self.translate[1] * orig_h)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translations = (tx, ty)
else:
translations = (0, 0)

if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
else:
scale = 1.0

shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())

shear = (shear_x, shear_y)

return dict(angle=angle, translate=translations, scale=scale, shear=shear)

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image) or is_simple_tensor(input):
fill = self.fill
if isinstance(fill, (int, float)):
num_channels, _, _ = get_image_dimensions(input)
fill = [float(fill)] * num_channels
else:
fill = [float(f) for f in fill]

output = F.affine_image_tensor(
input, **params, interpolation=self.interpolation, fill=fill, center=self.center
)

if isinstance(input, features.Image):
return features.Image.new_like(input, output)
return output
elif isinstance(input, PIL.Image.Image):
return F.affine_image_pil(
input, **params, interpolation=self.interpolation, fill=self.fill, center=self.center
)
elif isinstance(input, features.BoundingBox):
output = F.affine_bounding_box(input, input.image_size, **params, center=self.center)
return features.BoundingBox.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.affine_segmentation_mask(input, **params, center=self.center)
return features.SegmentationMask.new_like(input, output)
else:
return input
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def affine_image_tensor(
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
num_channels, height, width = get_dimensions_image_tensor(img)
batch_shape = img.shape[:-3]
output = _FT.affine(img.view(-1, num_channels, height, width), matrix, interpolation=interpolation.value, fill=fill)
return output.view(batch_shape + (num_channels, height, width))


def affine_image_pil(
Expand Down