 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
+from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

-
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


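The only substantive change in this hunk is the new `_center_crop_compute_padding` import, which the test further down uses to reproduce the kernel's pad-then-crop behavior. The helper itself is not part of this diff; the sketch below is an assumption about its contract inferred from how it is called in the test (the function name `_center_crop_padding_sketch` is hypothetical), not the actual implementation.

# Hypothetical stand-in for _center_crop_compute_padding, assuming it returns
# [left, top, right, bottom] padding that grows an image to at least the crop
# size, sending any odd leftover pixel to the right/bottom edge.
def _center_crop_padding_sketch(crop_height, crop_width, image_height, image_width):
    pad_width = max(crop_width - image_width, 0)
    pad_height = max(crop_height - image_height, 0)
    # Split each deficit in half; the odd pixel goes to the right/bottom.
    return [pad_width // 2, pad_height // 2, (pad_width + 1) // 2, (pad_height + 1) // 2]

# A 6x6 mask cropped to 7x6 would need one extra row at the bottom:
assert _center_crop_padding_sketch(7, 6, 6, 6) == [0, 0, 0, 1]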
@@ -421,6 +421,14 @@ def center_crop_bounding_box():
 )


+def center_crop_segmentation_mask():
+    for mask, output_size in itertools.product(
+        make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size
+    ):
+        yield SampleInput(mask, output_size)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
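For reference, the generator added above pairs every mask produced by `make_segmentation_masks` with every crop size via `itertools.product`. A minimal sketch of the grid it walks, assuming one mask per image size (the real helper yields several masks per size, so the actual sample count is larger):

import itertools

image_sizes = [(16, 16), (7, 33), (31, 9)]
crop_sizes = [[4, 3], [42, 70], [4]]  # smaller than, larger than, and square crops

# itertools.product yields the full 3 x 3 = 9 grid of (image_size, crop_size) pairs.
for image_size, output_size in itertools.product(image_sizes, crop_sizes):
    print(image_size, output_size)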
|
@@ -1337,3 +1345,26 @@ def _compute_expected_bbox(bbox, output_size_):
     else:
         expected_bboxes = expected_bboxes[0]
     torch.testing.assert_close(output_boxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
+def test_correctness_center_crop_segmentation_mask(device, output_size):
+    def _compute_expected_segmentation_mask(mask, output_size):
+        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]
+
+        _, image_height, image_width = mask.shape
+        if crop_width > image_width or crop_height > image_height:
+            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
+            mask = F.pad_image_tensor(mask, padding, fill=0)
+
+        left = round((image_width - crop_width) * 0.5)
+        top = round((image_height - crop_height) * 0.5)
+
+        return mask[:, top : top + crop_height, left : left + crop_width]
+
+    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
+    actual = F.center_crop_segmentation_mask(mask, output_size)
+
+    expected = _compute_expected_segmentation_mask(mask, output_size)
+    torch.testing.assert_close(expected, actual)
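End to end, the new test pins `F.center_crop_segmentation_mask` against a hand-rolled reference crop. A minimal usage sketch, assuming `torchvision.prototype.transforms.functional` is the `F` used throughout this file and that this PR's kernel is available:

import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long)  # (channels, height, width)

# Crop smaller than the mask: a plain center crop.
print(F.center_crop_segmentation_mask(mask, [4, 2]).shape)  # torch.Size([1, 4, 2])

# Crop larger than the mask: the mask is zero-padded first, so the
# output still has exactly the requested size.
print(F.center_crop_segmentation_mask(mask, [7, 6]).shape)  # torch.Size([1, 7, 6])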