
Commit 41a76ba

datumbox authored and facebook-github-bot committed
[fbsync] [proto] Added functional perspective_bounding_box/segmentation_mask ops (#5888)
Summary:
* Added functional `perspective_bounding_box` / `perspective_segmentation_mask` ops
* Added more comments and a check asserting that denom != 0
* Used a larger rtol/atol tolerance when matching bboxes

Reviewed By: YosuaMichael

Differential Revision: D36281602

fbshipit-source-id: 5976f009e2dad93ff9a1a84129e8bbc066e91cf6
1 parent 33a47cb commit 41a76ba

File tree (3 files changed, +283 −4 lines):

* test/test_prototype_transforms_functional.py
* torchvision/prototype/transforms/functional/__init__.py
* torchvision/prototype/transforms/functional/_geometry.py
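For orientation, the new prototype kernels can be exercised roughly as follows. This is a minimal illustrative sketch, not code from the commit; the coefficient setup mirrors the tests below, and the exact shapes and values are placeholders.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import _get_perspective_coeffs

# Coefficients returned by _get_perspective_coeffs(startpoints, endpoints)
# map each endpoint back onto its startpoint, which is the convention the kernels expect.
startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
coeffs = _get_perspective_coeffs(startpoints, endpoints)

# Transform a single XYXY box living in a 26x34 image.
bbox = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0]], format=features.BoundingBoxFormat.XYXY, image_size=(26, 34)
)
out_bbox = F.perspective_bounding_box(bbox, bbox.format, perspective_coeffs=coeffs)

# Transform an integer segmentation mask with the same perspective.
mask = torch.zeros(1, 26, 34, dtype=torch.long)
mask[0, 10:20, 10:20] = 1
out_mask = F.perspective_segmentation_mask(mask, perspective_coeffs=coeffs)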

test/test_prototype_transforms_functional.py

Lines changed: 192 additions & 4 deletions

@@ -11,8 +11,10 @@
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
+from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
 
+
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
 
 
@@ -380,6 +382,37 @@ def pad_segmentation_mask():
         yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def perspective_bounding_box():
+    for bounding_box, perspective_coeffs in itertools.product(
+        make_bounding_boxes(),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
+@register_kernel_info_from_sample_inputs_fn
+def perspective_segmentation_mask():
+    for mask, perspective_coeffs in itertools.product(
+        make_segmentation_masks(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            mask,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -985,7 +1018,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
-    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
         # bbox should be xyxy
         bbox[0] = (bbox[0] - left_) * size_[1] / width_
         bbox[1] = (bbox[1] - top_) * size_[0] / height_
@@ -1001,7 +1034,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ]
     expected_bboxes = []
     for in_box in in_boxes:
-        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
     expected_bboxes = torch.tensor(expected_bboxes, device=device)
 
     in_boxes = features.BoundingBox(
@@ -1027,7 +1060,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ],
 )
 def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
-    def _compute_expected(mask, top_, left_, height_, width_, size_):
+    def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
         output = mask.clone()
         output = output[:, top_ : top_ + height_, left_ : left_ + width_]
         output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
@@ -1038,7 +1071,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
     in_mask[0, 10:20, 10:20] = 1
     in_mask[0, 5:15, 12:23] = 2
 
-    expected_mask = _compute_expected(in_mask, top, left, height, width, size)
+    expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)
     output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
     torch.testing.assert_close(output_mask, expected_mask)
 
@@ -1085,3 +1118,158 @@ def parse_padding():
 
     expected_mask = _compute_expected_mask()
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
+    def _compute_expected_bbox(bbox, pcoeffs_):
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+            ]
+        )
+        numer = np.matmul(points, m1.T)
+        denom = np.matmul(points, m2.T)
+        transformed_points = numer / denom
+        out_bbox = [
+            np.min(transformed_points[:, 0]),
+            np.min(transformed_points[:, 1]),
+            np.max(transformed_points[:, 0]),
+            np.max(transformed_points[:, 1]),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYXY,
+            image_size=bbox.image_size,
+            dtype=torch.float32,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+
+    image_size = (32, 38)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.perspective_bounding_box(
+            bboxes,
+            bboxes_format,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
+    def _compute_expected_mask(mask, pcoeffs_):
+        assert mask.ndim == 3 and mask.shape[0] == 1
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        expected_mask = torch.zeros_like(mask.cpu())
+        for out_y in range(expected_mask.shape[1]):
+            for out_x in range(expected_mask.shape[2]):
+                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
+
+                numer = np.matmul(output_pt, m1.T)
+                denom = np.matmul(output_pt, m2.T)
+                input_pt = np.floor(numer / denom).astype(np.int32)
+
+                in_x, in_y = input_pt[:2]
+                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
+                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+        return expected_mask.to(mask.device)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+
+    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+        mask = mask.to(device)
+
+        output_mask = F.perspective_segmentation_mask(
+            mask,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if mask.ndim < 4:
+            masks = [mask]
+        else:
+            masks = [m for m in mask]
+
+        expected_masks = []
+        for mask in masks:
+            expected_mask = _compute_expected_mask(mask, pcoeffs)
+            expected_masks.append(expected_mask)
+        if len(expected_masks) > 1:
+            expected_masks = torch.stack(expected_masks)
+        else:
+            expected_masks = expected_masks[0]
+        torch.testing.assert_close(output_mask, expected_masks)
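Note how the reference `_compute_expected_bbox` above is fed `inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)` while the kernel under test receives `pcoeffs`: the coefficients describe the endpoint-to-startpoint (output-to-input) mapping used for image sampling, so pushing box corners forward requires the inverse set. A quick way to see the convention (an illustrative snippet, not part of the commit):

from torchvision.transforms.functional import _get_perspective_coeffs

startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
c = _get_perspective_coeffs(startpoints, endpoints)

# The returned coefficients map each endpoint back onto its corresponding startpoint.
x, y = endpoints[0]
den = c[6] * x + c[7] * y + 1
mapped = ((c[0] * x + c[1] * y + c[2]) / den, (c[3] * x + c[4] * y + c[5]) / den)
print(mapped)  # approximately (0.0, 0.0), i.e. startpoints[0]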

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -67,8 +67,10 @@
     crop_image_tensor,
     crop_image_pil,
     crop_segmentation_mask,
+    perspective_bounding_box,
     perspective_image_tensor,
     perspective_image_pil,
+    perspective_segmentation_mask,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
     vertical_flip_bounding_box,
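With these two names exported, the new kernels become importable from the prototype functional namespace; a minimal sketch, assuming the package layout shown above:

from torchvision.prototype.transforms.functional import (
    perspective_bounding_box,
    perspective_segmentation_mask,
)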

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 89 additions & 0 deletions

@@ -472,6 +472,95 @@ def perspective_image_pil(
     return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
 
 
+def perspective_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    perspective_coeffs: List[float],
+) -> torch.Tensor:
+
+    if len(perspective_coeffs) != 8:
+        raise ValueError("Argument perspective_coeffs should have 8 float values")
+
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+    device = bounding_box.device
+
+    # perspective_coeffs are computed as endpoint -> start point
+    # We have to invert perspective_coeffs for bboxes:
+    # (x, y) - end point and (x_out, y_out) - start point
+    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # and we would like to get:
+    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
+    #     / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
+    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
+    #     / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
+    # and compute inv_coeffs in terms of coeffs
+
+    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
+    if denom == 0:
+        raise RuntimeError(
+            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
+            f"Denominator is zero, denom={denom}"
+        )
+
+    inv_coeffs = [
+        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
+        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
+        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
+        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
+        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
+        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
+        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
+        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
+    ]
+
+    theta1 = torch.tensor(
+        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
+        dtype=dtype,
+        device=device,
+    )
+
+    theta2 = torch.tensor(
+        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
+    )
+
+    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
+    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
+    # Single point structure is similar to
+    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
+    # 2) Now let's transform the points using perspective matrices
+    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+
+    numer_points = torch.matmul(points, theta1.T)
+    denom_points = torch.matmul(points, theta2.T)
+    transformed_points = numer_points / denom_points
+
+    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
+    # and compute bounding box from 4 transformed points:
+    transformed_points = transformed_points.view(-1, 4, 2)
+    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
+    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
+    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+
+    # out_bboxes should be of shape [N boxes, 4]
+
+    return convert_bounding_box_format(
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(original_shape)
+
+
+def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
+    return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST)
+
+
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
         return [int(output_size), int(output_size)]
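For reference, the closed-form `inv_coeffs` above are the first eight entries of the adjugate of the full 3x3 homography `H = [[c0, c1, c2], [c3, c4, c5], [c6, c7, 1]]` divided by its bottom-right cofactor `c0*c4 - c1*c3`, which is exactly the `denom` guarded against zero. A quick numerical cross-check is sketched below; `invert_perspective_coeffs` is a hypothetical helper, not part of the commit.

import numpy as np

def invert_perspective_coeffs(coeffs):
    # Build the full 3x3 homography; only the first 8 entries are stored,
    # the bottom-right entry is implicitly 1.
    h = np.array(list(coeffs) + [1.0]).reshape(3, 3)
    inv = np.linalg.inv(h)
    inv = inv / inv[2, 2]  # renormalize so the last entry is 1 again
    return inv.reshape(-1)[:8].tolist()

coeffs = [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018]
print(invert_perspective_coeffs(coeffs))
# Should match, up to floating point error, the inv_coeffs computed in
# perspective_bounding_box for the same input coefficients.

Writing the inverse out element-wise presumably avoids building and inverting a 3x3 tensor inside the kernel on every call.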
