Skip to content

Commit 1502ed9

Browse files
authored
[proto] Small optims for perspective bboxes op (#6891)
* [proto] Speed-up crop on bboxes and tests
* Fix linter
* Update _geometry.py
* Fixed device issue
* Revert changes in test/prototype_transforms_kernel_infos.py
* Fixed failing correctness tests
* [proto] Optimized functional pad op for bboxes + tests
* Renamed copy-pasted variable name
* [proto] Small optims for perspective bboxes op
1 parent a2151b9 commit 1502ed9

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -928,16 +928,16 @@ def perspective_bounding_box(
928928
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
929929
]
930930

931-
theta1 = torch.tensor(
932-
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
931+
theta12_T = torch.tensor(
932+
[
933+
[inv_coeffs[0], inv_coeffs[3], inv_coeffs[6], inv_coeffs[6]],
934+
[inv_coeffs[1], inv_coeffs[4], inv_coeffs[7], inv_coeffs[7]],
935+
[inv_coeffs[2], inv_coeffs[5], 1.0, 1.0],
936+
],
933937
dtype=dtype,
934938
device=device,
935939
)
936940

937-
theta2 = torch.tensor(
938-
[[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
939-
)
940-
941941
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
942942
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
943943
# Single point structure is similar to
@@ -948,15 +948,16 @@ def perspective_bounding_box(
948948
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
949949
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
950950

951-
numer_points = torch.matmul(points, theta1.T)
952-
denom_points = torch.matmul(points, theta2.T)
953-
transformed_points = numer_points / denom_points
951+
numer_denom_points = torch.matmul(points, theta12_T)
952+
numer_points = numer_denom_points[:, :2]
953+
denom_points = numer_denom_points[:, 2:]
954+
transformed_points = numer_points.div_(denom_points)
954955

955956
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
956957
# and compute bounding box from 4 transformed points:
957958
transformed_points = transformed_points.reshape(-1, 4, 2)
958-
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
959-
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
959+
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
960+
960961
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
961962

962963
# out_bboxes should be of shape [N boxes, 4]

0 commit comments

Comments
 (0)