Merged
Changes from 13 commits
2 changes: 1 addition & 1 deletion test/prototype_transforms_kernel_infos.py
@@ -915,7 +915,7 @@ def sample_inputs_rotate_video():
reference_inputs_fn=reference_inputs_rotate_image_tensor,
float32_vs_uint8=True,
# TODO: investigate
closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
closeness_kwargs=pil_reference_pixel_difference(110, agg_method="mean"),
Contributor Author
Flaky test unrelated to this PR that popped up previously on another PR:

Unrelated flakiness:
FAILED test/test_prototype_transforms_functional.py::TestKernels::test_against_reference[rotate_image_tensor-38] - AssertionError: The 'mean' of the absolute difference is 104.21571906354515, but only 100.0 is allowed.

test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression, since it seems this should be supported if `int` is ok
1 change: 1 addition & 0 deletions test/test_prototype_transforms_consistency.py
@@ -401,6 +401,7 @@ def __init__(
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
),
ConsistencyConfig(
prototype_transforms.RandomRotation,
209 changes: 193 additions & 16 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -1,16 +1,16 @@
import math
import numbers
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import interpolate, pad as torch_pad
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
_compute_resized_output_size as __compute_resized_output_size,
_get_inverse_affine_matrix,
_get_perspective_coeffs,
InterpolationMode,
pil_modes_mapping,
@@ -272,6 +272,168 @@ def _affine_parse_args(
return angle, translate, shear, center


def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
Comment on lines +275 to +277
Contributor Author
Here we do some caching of intermediate results and minor refactoring (especially with the negative values) to make the code a bit more readable IMO. Can be reverted.
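
As a quick sanity check of the refactor, here is a minimal sketch (hypothetical parameter values; it assumes the `_get_inverse_affine_matrix` defined below is in scope) that builds the forward and inverted matrices and verifies they compose to the identity:

import torch

# Hypothetical parameters for the check.
args = dict(center=[10.0, 20.0], angle=30.0, translate=[5.0, -3.0], scale=1.5, shear=[10.0, 5.0])
fwd = _get_inverse_affine_matrix(**args, inverted=False)  # M = T * C * RSS * C^-1
inv = _get_inverse_affine_matrix(**args, inverted=True)   # M^-1 = C * RSS^-1 * C^-1 * T^-1

def to_3x3(m):
    # Promote the flat [a, b, tx, c, d, ty] affine vector to a 3x3 matrix.
    return torch.tensor([[m[0], m[1], m[2]], [m[3], m[4], m[5]], [0.0, 0.0, 1.0]], dtype=torch.float64)

# The product of the two matrices should be the identity (up to float error).
assert torch.allclose(to_3x3(fwd) @ to_3x3(inv), torch.eye(3, dtype=torch.float64), atol=1e-6)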

# Helper method to compute inverse matrix for affine transformation

# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

rot = math.radians(angle)
sx = math.radians(shear[0])
sy = math.radians(shear[1])

cx, cy = center
tx, ty = translate

# Cached results
cossy = math.cos(sy)
tansx = math.tan(sx)
rot_sy = rot - sy
cx_plus_tx = cx + tx
cy_plus_ty = cy + ty

# RSS without scaling
a = math.cos(rot_sy) / cossy
b = -(a * tansx + math.sin(rot))
c = math.sin(rot_sy) / cossy
d = math.cos(rot) - c * tansx

if inverted:
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
# and then apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
else:
matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
# Apply inverse of center translation: RSS * C^-1
# and then apply translation and center : T * C * RSS * C^-1
matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
Contributor Author
aminmax call + a bunch of in-place ops to speed things up.
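
For illustration, a tiny standalone sketch (hypothetical points) of the two tricks in this comment: a single `aminmax` reduction instead of separate `min`/`max` passes, and in-place arithmetic on its outputs to avoid temporaries:

import torch

pts = torch.tensor([[-2.0, 1.0], [3.0, -4.0], [0.5, 2.5]])

# One reduction kernel returning both extremes along dim=0.
min_vals, max_vals = pts.aminmax(dim=0)
assert torch.equal(min_vals, pts.amin(dim=0)) and torch.equal(max_vals, pts.amax(dim=0))

# In-place ops mutate the reduction outputs directly, allocating no new tensors.
shift = torch.tensor([1.0, 1.0])
min_vals.add_(shift)
max_vals.add_(shift)
size = max_vals.sub_(min_vals)  # reuses max_vals' storage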

# Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
half_w = 0.5 * w
half_h = 0.5 * h
pts = torch.tensor(
[
[-half_w, -half_h, 1.0],
[-half_w, half_h, 1.0],
[half_w, half_h, 1.0],
[half_w, -half_h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, max_vals = new_pts.aminmax(dim=0)

# shift points to [0, w] and [0, h] interval to match PIL results
halfs = torch.tensor((half_w, half_h))
min_vals.add_(halfs)
max_vals.add_(halfs)

# Truncate precision to 1e-4 to avoid ceil() rounding values like Xe-15 up to 1.0
tol = 1e-4
inv_tol = 1.0 / tol
cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
size = cmax.sub_(cmin)
return int(size[0]), int(size[1]) # w, h


def _apply_grid_transform(
Contributor Author
We do in-place ops where possible, plus a bit of refactoring. An important note is that the input image must be float, because the handling/casting can be done more efficiently outside of the method.
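
A minimal sketch of the caller-side contract this creates (hypothetical wrapper name; it mirrors the pattern used by the kernels further down in this diff): cast to float once before the call and round/cast back once after, instead of branching inside the helper:

import torch

def sample_with_grid(image: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    # The helper assumes a float input, so any casting happens once, out here.
    fp = torch.is_floating_point(image)
    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, mode="bilinear", fill=None)

    # Integer inputs are rounded and cast back exactly once, after sampling.
    if not fp:
        output = output.round_().to(image.dtype)
    return output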

float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> torch.Tensor:

shape = float_img.shape
if shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(shape[0], -1, -1, -1)

# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
float_img = torch.cat((float_img, mask), dim=1)

float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

# Fill with required color
if fill is not None:
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
mask = mask.expand_as(float_img)
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest":
bool_mask = mask < 0.5
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
else: # 'bilinear'
# The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

return float_img


def _assert_grid_transform_inputs(
Contributor Author
Minor clean ups to make the if statements clearer.

img: torch.Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if matrix is not None:
if not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
elif len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")

if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
Comment on lines +418 to +419
Contributor
What is coeffs in an affine transformation? That seems out of place here.

Contributor Author
It's the perspective coeff. I'm not writing new code, I'm porting the existing methods here. I don't think changing the validation should be in scope of this PR, because it will grow really quickly.

Contributor
Let me send a follow-up PR then.


if fill is not None:
if isinstance(fill, (tuple, list)):
length = len(fill)
num_channels = img.shape[-3]
if length > 1 and length != num_channels:
raise ValueError(
"The number of elements in 'fill' cannot broadcast to match the number of "
f"channels of the image ({length} != {num_channels})"
)
elif not isinstance(fill, (int, float)):
raise ValueError("Argument fill should be either int, float, tuple or list")

if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def affine_image_tensor(
image: torch.Tensor,
angle: Union[int, float],
@@ -395,7 +557,7 @@ def _affine_bounding_box_xyxy(
out_bboxes.sub_(tr.repeat((1, 2)))
# Estimate meta-data for image with inverted=True and with center=[0,0]
affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
spatial_size = (new_height, new_width)

return out_bboxes.to(bounding_box.dtype), spatial_size
@@ -552,7 +714,7 @@ def rotate_image_tensor(
)
new_height, new_width = image.shape[-2:]
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
Comment on lines +773 to +774
Contributor Author
Here I don't adopt the squash idiom from other places because it would lead to more complex logic. I have concerns about whether the implementation followed here (maintained from main) actually handles all corner cases properly (ill-formed images with 0 elements), similar to the issue observed with pad.

@vfdev-5 Might be worth taking a look on your side to see if a mitigation is necessary similar to #6949 (aka sending the image through the kernel normally).
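
For reference, a minimal standalone sketch of the squash idiom under discussion (a hypothetical helper; the real kernels inline this logic): flatten any leading batch dims into one so the kernel only ever sees a 4D tensor, then restore the original leading dims afterwards.

import torch

def run_squashed(image: torch.Tensor, kernel) -> torch.Tensor:
    shape = image.shape
    if image.ndim > 4:
        # Collapse all leading dims into one batch dim: (..., C, H, W) -> (N, C, H, W).
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif image.ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = kernel(image)

    if needs_unsquash:
        # Restore the leading dims; the spatial size may have changed (e.g. expand=True).
        output = output.reshape(shape[:-3] + output.shape[-3:])
    return output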


return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

@@ -944,7 +1106,6 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
# TODO: should we define them transposed?
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
@@ -959,8 +1120,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
base_grid[..., 2].fill_(1)

rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
shape = (1, oh * ow, 3)
output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

output_grid = output_grid1.div_(output_grid2).sub_(1.0)
return output_grid.view(1, oh, ow, 2)
@@ -996,14 +1158,19 @@ def perspective_image_tensor(
return image

shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)

if image.ndim > 4:
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False

_FT._assert_grid_transform_inputs(
_assert_grid_transform_inputs(
image,
matrix=None,
interpolation=interpolation.value,
Expand All @@ -1012,10 +1179,13 @@ def perspective_image_tensor(
coeffs=perspective_coeffs,
)

ow, oh = image.shape[-1], image.shape[-2]
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
oh, ow = shape[-2:]
dtype = image.dtype if fp else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)

if not fp:
output = output.round_().to(image.dtype)

if needs_unsquash:
output = output.reshape(shape)
@@ -1086,7 +1256,6 @@ def perspective_bounding_box(
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]

# TODO: should we define them transposed?
theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
@@ -1193,17 +1362,25 @@ def elastic_image_tensor(
return image

shape = image.shape
ndim = image.ndim
device = image.device
fp = torch.is_floating_point(image)

if image.ndim > 4:
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False

image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)

if not fp:
output = output.round_().to(image.dtype)

if needs_unsquash:
output = output.reshape(shape)
@@ -1361,7 +1538,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height: