[prototype] Optimize and clean up all affine methods #6945
Changes from 13 commits
@@ -1,16 +1,16 @@
 import math
 import numbers
 import warnings
 from typing import List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
-from torch.nn.functional import interpolate, pad as torch_pad
+from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
     _compute_resized_output_size as __compute_resized_output_size,
     _get_inverse_affine_matrix,
     _get_perspective_coeffs,
     InterpolationMode,
     pil_modes_mapping,
@@ -272,6 +272,168 @@ def _affine_parse_args(
     return angle, translate, shear, center


+def _get_inverse_affine_matrix(
+    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
+) -> List[float]:
Comment on lines +275 to +277, Contributor (Author): Here we do some caching of intermediate results and minor refactoring (especially with the negative values) to make the code a bit more readable, IMO. Can be reverted.
+    # Helper method to compute inverse matrix for affine transformation
+
+    # Pillow requires inverse affine transformation matrix:
+    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
+    #
+    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+    #       RotateScaleShear is rotation with scale and shear matrix
+    #
+    #       RotateScaleShear(a, s, (sx, sy)) =
+    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
+    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
+    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
+    #         [ 0                    , 0                                        , 1 ]
+    #
+    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
+    # SHx(s) = [1, -tan(s)]    and    SHy(s) = [1      , 0]
+    #          [0, 1      ]                    [-tan(s), 1]
+    #
+    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
+
+    rot = math.radians(angle)
+    sx = math.radians(shear[0])
+    sy = math.radians(shear[1])
+
+    cx, cy = center
+    tx, ty = translate
+
+    # Cached results
+    cossy = math.cos(sy)
+    tansx = math.tan(sx)
+    rot_sy = rot - sy
+    cx_plus_tx = cx + tx
+    cy_plus_ty = cy + ty
+
+    # RSS without scaling
+    a = math.cos(rot_sy) / cossy
+    b = -(a * tansx + math.sin(rot))
+    c = math.sin(rot_sy) / cossy
+    d = math.cos(rot) - c * tansx
+
+    if inverted:
+        # Inverted rotation matrix with scale and shear
+        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
+        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
+        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
+    else:
+        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
+        # Apply inverse of center translation: RSS * C^-1
+        # and then apply translation and center: T * C * RSS * C^-1
+        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
+        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy
+
+    return matrix
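As a quick sanity check of the conventions above (a hypothetical snippet, not part of the diff): composing the `inverted=False` and `inverted=True` outputs as 3x3 matrices should give the identity, since the two calls produce M and M^-1.

```python
import torch

# Hypothetical check: append the implicit [0, 0, 1] row and compose M with M^-1.
args = dict(center=[16.0, 16.0], angle=30.0, translate=[2.0, 3.0], scale=1.5, shear=[10.0, 5.0])
fwd = torch.tensor(_get_inverse_affine_matrix(**args, inverted=False) + [0.0, 0.0, 1.0]).view(3, 3)
inv = torch.tensor(_get_inverse_affine_matrix(**args, inverted=True) + [0.0, 0.0, 1.0]).view(3, 3)
assert torch.allclose(fwd @ inv, torch.eye(3), atol=1e-4)
```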
+
+
+def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+    # Inspired by the PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
+
+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+    # Points are shifted due to affine matrix torch convention about
+    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
+    half_w = 0.5 * w
+    half_h = 0.5 * h
+    pts = torch.tensor(
+        [
+            [-half_w, -half_h, 1.0],
+            [-half_w, half_h, 1.0],
+            [half_w, half_h, 1.0],
+            [half_w, -half_h, 1.0],
+        ]
+    )
+    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
+    new_pts = torch.matmul(pts, theta.T)
+    min_vals, max_vals = new_pts.aminmax(dim=0)
+
+    # shift points to [0, w] and [0, h] interval to match PIL results
+    halfs = torch.tensor((half_w, half_h))
+    min_vals.add_(halfs)
+    max_vals.add_(halfs)
+
+    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
+    tol = 1e-4
+    inv_tol = 1.0 / tol
+    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
+    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
+    size = cmax.sub_(cmin)
+    return int(size[0]), int(size[1])  # w, h
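For intuition, a hypothetical usage sketch mirroring the call in `_affine_bounding_box_xyxy` further down: a 45-degree rotation of a 100x100 image needs a 142x142 canvas, since the rotated bounding box is 100 * sqrt(2) ≈ 141.42 pixels on each side, rounded up.

```python
# Hypothetical: pure rotation about the center, no translate/scale/shear.
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle=45.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0])
print(_compute_affine_output_size(matrix, 100, 100))  # (142, 142)
```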
Contributor (Author): We do in-place ops where possible, plus a bit of refactoring. An important note is that the input image must be float, because the handling/casting can be done more efficiently outside of the method.

+
+
+def _apply_grid_transform(
+    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
+) -> torch.Tensor:
+
+    shape = float_img.shape
+    if shape[0] > 1:
+        # Apply same grid to a batch of images
+        grid = grid.expand(shape[0], -1, -1, -1)
+
+    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
+    if fill is not None:
+        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        float_img = torch.cat((float_img, mask), dim=1)
+
+    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    # Fill with required color
+    if fill is not None:
+        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
+        mask = mask.expand_as(float_img)
+        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
+        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
+        if mode == "nearest":
+            bool_mask = mask < 0.5
+            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
+        else:  # 'bilinear'
+            # The following is mathematically equivalent to:
+            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
+            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
+
+    return float_img
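For intuition, a minimal round-trip sketch (hypothetical, built on torch's `affine_grid`): an identity grid reproduces a float input exactly, and a single grid is expanded across the batch as described above.

```python
import torch
from torch.nn.functional import affine_grid

img = torch.rand(2, 3, 8, 8)  # batch of 2, so the grid gets expanded internally
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine
grid = affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
out = _apply_grid_transform(img, grid, mode="bilinear", fill=None)
assert torch.allclose(out, img, atol=1e-6)
```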
Contributor (Author): Minor clean-ups to make the if statements clearer.

+
+
+def _assert_grid_transform_inputs(
+    img: torch.Tensor,
+    matrix: Optional[List[float]],
+    interpolation: str,
+    fill: Optional[Union[int, float, List[float]]],
+    supported_interpolation_modes: List[str],
+    coeffs: Optional[List[float]] = None,
+) -> None:
+    if matrix is not None:
+        if not isinstance(matrix, list):
+            raise TypeError("Argument matrix should be a list")
+        elif len(matrix) != 6:
+            raise ValueError("Argument matrix should have 6 float values")
+
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
Comment on lines +418 to +419
Contributor: What is coeffs?
Contributor: See #6945 (comment).
Contributor (Author): It's the perspective coeff. I'm not writing new code, I'm porting the existing methods here. I don't think changing the validation should be in scope for this PR because it would grow really quickly.
Contributor: Let me send a follow-up PR then.
+
+    if fill is not None:
+        if isinstance(fill, (tuple, list)):
+            length = len(fill)
+            num_channels = img.shape[-3]
+            if length > 1 and length != num_channels:
+                raise ValueError(
+                    "The number of elements in 'fill' cannot broadcast to match the number of "
+                    f"channels of the image ({length} != {num_channels})"
+                )
+        elif not isinstance(fill, (int, float)):
+            raise ValueError("Argument fill should be either int, float, tuple or list")
+
+    if interpolation not in supported_interpolation_modes:
+        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
+
+
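A quick illustration of the fill check above (hypothetical call): a 2-element fill cannot broadcast to a 3-channel image, while a 1- or 3-element fill would pass.

```python
import torch

img = torch.zeros(3, 8, 8)
try:
    _assert_grid_transform_inputs(
        img, matrix=None, interpolation="bilinear", fill=[0.0, 255.0],
        supported_interpolation_modes=["nearest", "bilinear"],
    )
except ValueError as exc:
    print(exc)  # fill has 2 elements but cannot broadcast to 3 channels
```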
 def affine_image_tensor(
     image: torch.Tensor,
     angle: Union[int, float],
@@ -395,7 +557,7 @@ def _affine_bounding_box_xyxy(
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
     affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
-    new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
+    new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
     spatial_size = (new_height, new_width)

     return out_bboxes.to(bounding_box.dtype), spatial_size
@@ -552,7 +714,7 @@ def rotate_image_tensor(
         )
         new_height, new_width = image.shape[-2:]
     else:
-        new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
+        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
Comment on lines +773 to +774, Contributor (Author): Here I don't adopt the squash idiom from other places because it would lead to more complex logic. I have concerns about whether the implementation followed here (maintained from main) actually handles all corner cases properly (ill-formed images with 0 elements), similar to the issue observed with pad. @vfdev-5 Might be worth taking a look on your side to see if a mitigation is necessary similar to #6949 (aka sending the image through the kernel normally).
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
@@ -944,7 +1106,6 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
@@ -959,8 +1120,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     base_grid[..., 2].fill_(1)

     rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
-    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
-    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+    shape = (1, oh * ow, 3)
+    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
+    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

     output_grid = output_grid1.div_(output_grid2).sub_(1.0)
     return output_grid.view(1, oh, ow, 2)
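A spot-check of the mapping described in the comment at the top of `_perspective_grid` (hypothetical numbers): with coeffs = [1, 0, 5, 0, 1, 3, 0, 0] the denominator is 1 and the transform is a pure translation.

```python
coeffs = [1.0, 0.0, 5.0, 0.0, 1.0, 3.0, 0.0, 0.0]
x, y = 10.0, 20.0
denom = coeffs[6] * x + coeffs[7] * y + 1.0
x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / denom
y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / denom
assert (x_out, y_out) == (15.0, 23.0)  # (x + 5, y + 3)
```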
@@ -996,14 +1158,19 @@ def perspective_image_tensor(
         return image

     shape = image.shape
+    ndim = image.ndim
+    fp = torch.is_floating_point(image)

-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False

-    _FT._assert_grid_transform_inputs(
+    _assert_grid_transform_inputs(
         image,
         matrix=None,
         interpolation=interpolation.value,
@@ -1012,10 +1179,13 @@ def perspective_image_tensor(
         coeffs=perspective_coeffs,
     )

-    ow, oh = image.shape[-1], image.shape[-2]
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    oh, ow = shape[-2:]
+    dtype = image.dtype if fp else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
+    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -1086,7 +1256,6 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]

-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
@@ -1193,17 +1362,25 @@ def elastic_image_tensor(
         return image

     shape = image.shape
+    ndim = image.ndim
     device = image.device
+    fp = torch.is_floating_point(image)

-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False

     image_height, image_width = shape[-2:]
     grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
+    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)

     if needs_unsquash:
         output = output.reshape(shape)
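A hypothetical sketch, assuming the kernel keeps the (image, displacement, interpolation, fill) signature suggested by this hunk: a zero displacement field reduces the grid to the identity, so a float image should round-trip unchanged.

```python
import torch

img = torch.rand(1, 3, 16, 16)
displacement = torch.zeros(1, 16, 16, 2)  # (1, H, W, 2), zero offsets
out = elastic_image_tensor(img, displacement, interpolation=InterpolationMode.BILINEAR, fill=None)
assert torch.allclose(out, img, atol=1e-5)
```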
@@ -1361,7 +1538,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
+        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

     image_height, image_width = image.shape[-2:]
     if crop_width == image_width and crop_height == image_height:
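A hypothetical sketch of the padding branch above: requesting a crop larger than the input pads with zeros first, so the output has the requested size with the original values centered.

```python
import torch

img = torch.arange(4.0).reshape(1, 1, 2, 2)
out = center_crop_image_tensor(img, [4, 4])
print(out.shape)  # torch.Size([1, 1, 4, 4])
```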
Comment: Flaky test unrelated to this PR that popped up previously on another PR: