diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6656ecfe85b..e7373000967 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -928,16 +928,16 @@ def perspective_bounding_box( (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, ] - theta1 = torch.tensor( - [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], + theta12_T = torch.tensor( + [ + [inv_coeffs[0], inv_coeffs[3], inv_coeffs[6], inv_coeffs[6]], + [inv_coeffs[1], inv_coeffs[4], inv_coeffs[7], inv_coeffs[7]], + [inv_coeffs[2], inv_coeffs[5], 1.0, 1.0], + ], dtype=dtype, device=device, ) - theta2 = torch.tensor( - [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device - ) - # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to @@ -948,15 +948,16 @@ def perspective_bounding_box( # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) - numer_points = torch.matmul(points, theta1.T) - denom_points = torch.matmul(points, theta2.T) - transformed_points = numer_points / denom_points + numer_denom_points = torch.matmul(points, theta12_T) + numer_points = numer_denom_points[:, :2] + denom_points = numer_denom_points[:, 2:] + transformed_points = numer_points.div_(denom_points) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] # and compute bounding box from 4 transformed points: transformed_points = transformed_points.reshape(-1, 4, 2) - out_bbox_mins, _ = torch.min(transformed_points, dim=1) - out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype) # out_bboxes should be of shape [N boxes, 4]