Skip to content

Commit e55c02e

Browse files
Yosua Michael Maranathafacebook-github-bot
Yosua Michael Maranatha
authored andcommitted
[fbsync] [proto] Small optim for perspective op on images (#6907)
Summary: * [proto] small optim for perspective op on images, reverted concat trick on bboxes * revert unrelated changes * PR review updates * PR review change Reviewed By: NicolasHug Differential Revision: D41265184 fbshipit-source-id: 12073a164180b2ed392dd455106f6411bab9a317
1 parent 41b0a19 commit e55c02e

File tree

1 file changed

+51
-10
lines changed

1 file changed

+51
-10
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,36 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i
907907
return crop_image_pil(inpt, top, left, height, width)
908908

909909

910+
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
911+
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
912+
# src/libImaging/Geometry.c#L394
913+
914+
#
915+
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
916+
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
917+
#
918+
919+
theta1 = torch.tensor(
920+
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
921+
)
922+
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
923+
924+
d = 0.5
925+
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
926+
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
927+
base_grid[..., 0].copy_(x_grid)
928+
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
929+
base_grid[..., 1].copy_(y_grid)
930+
base_grid[..., 2].fill_(1)
931+
932+
rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
933+
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
934+
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
935+
936+
output_grid = output_grid1.div_(output_grid2).sub_(1.0)
937+
return output_grid.view(1, oh, ow, 2)
938+
939+
910940
def _perspective_coefficients(
911941
startpoints: Optional[List[List[int]]],
912942
endpoints: Optional[List[List[int]]],
@@ -944,7 +974,19 @@ def perspective_image_tensor(
944974
else:
945975
needs_unsquash = False
946976

947-
output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
977+
_FT._assert_grid_transform_inputs(
978+
image,
979+
matrix=None,
980+
interpolation=interpolation.value,
981+
fill=fill,
982+
supported_interpolation_modes=["nearest", "bilinear"],
983+
coeffs=perspective_coeffs,
984+
)
985+
986+
ow, oh = image.shape[-1], image.shape[-2]
987+
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
988+
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
989+
output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
948990

949991
if needs_unsquash:
950992
output = output.reshape(shape)
@@ -1012,16 +1054,16 @@ def perspective_bounding_box(
10121054
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
10131055
]
10141056

1015-
theta12_T = torch.tensor(
1016-
[
1017-
[inv_coeffs[0], inv_coeffs[3], inv_coeffs[6], inv_coeffs[6]],
1018-
[inv_coeffs[1], inv_coeffs[4], inv_coeffs[7], inv_coeffs[7]],
1019-
[inv_coeffs[2], inv_coeffs[5], 1.0, 1.0],
1020-
],
1057+
theta1 = torch.tensor(
1058+
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
10211059
dtype=dtype,
10221060
device=device,
10231061
)
10241062

1063+
theta2 = torch.tensor(
1064+
[[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
1065+
)
1066+
10251067
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
10261068
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
10271069
# Single point structure is similar to
@@ -1032,9 +1074,8 @@ def perspective_bounding_box(
10321074
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
10331075
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
10341076

1035-
numer_denom_points = torch.matmul(points, theta12_T)
1036-
numer_points = numer_denom_points[:, :2]
1037-
denom_points = numer_denom_points[:, 2:]
1077+
numer_points = torch.matmul(points, theta1.T)
1078+
denom_points = torch.matmul(points, theta2.T)
10381079
transformed_points = numer_points.div_(denom_points)
10391080

10401081
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]

0 commit comments

Comments
 (0)