diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 24b240b256f..adf494b1c42 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -907,6 +907,36 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i return crop_image_pil(inpt, top, left, height, width) +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)) + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) + output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1.div_(output_grid2).sub_(1.0) + return output_grid.view(1, oh, ow, 2) + + def _perspective_coefficients( startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], @@ -944,7 +974,19 @@ def perspective_image_tensor( else: needs_unsquash = False - output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill) + _FT._assert_grid_transform_inputs( + image, + matrix=None, + interpolation=interpolation.value, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + ow, oh = image.shape[-1], image.shape[-2] + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) + output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill) if needs_unsquash: output = output.reshape(shape) @@ -1012,16 +1054,16 @@ def perspective_bounding_box( (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, ] - theta12_T = torch.tensor( - [ - [inv_coeffs[0], inv_coeffs[3], inv_coeffs[6], inv_coeffs[6]], - [inv_coeffs[1], inv_coeffs[4], inv_coeffs[7], inv_coeffs[7]], - [inv_coeffs[2], inv_coeffs[5], 1.0, 1.0], - ], + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], dtype=dtype, device=device, ) + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to @@ -1032,9 +1074,8 @@ def perspective_bounding_box( # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) - numer_denom_points = torch.matmul(points, theta12_T) - numer_points = numer_denom_points[:, :2] - denom_points = numer_denom_points[:, 2:] + numer_points = torch.matmul(points, theta1.T) + denom_points = torch.matmul(points, theta2.T) transformed_points = numer_points.div_(denom_points) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]