|
10 | 10 | from test_prototype_transforms_functional import (
|
11 | 11 | make_bounding_box,
|
12 | 12 | make_bounding_boxes,
|
| 13 | + make_image, |
13 | 14 | make_images,
|
14 | 15 | make_label,
|
15 | 16 | make_one_hot_labels,
|
@@ -1328,3 +1329,161 @@ def test__transform(self, mocker):
|
1328 | 1329 | transform(inpt_sentinel)
|
1329 | 1330 |
|
1330 | 1331 | mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
|
| 1332 | + |
| 1333 | + |
| 1334 | +class TestFixedSizeCrop: |
| 1335 | + def test__get_params(self, mocker): |
| 1336 | + crop_size = (7, 7) |
| 1337 | + batch_shape = (10,) |
| 1338 | + image_size = (11, 5) |
| 1339 | + |
| 1340 | + transform = transforms.FixedSizeCrop(size=crop_size) |
| 1341 | + |
| 1342 | + sample = dict( |
| 1343 | + image=make_image(size=image_size, color_space=features.ColorSpace.RGB), |
| 1344 | + bounding_boxes=make_bounding_box( |
| 1345 | + format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape |
| 1346 | + ), |
| 1347 | + ) |
| 1348 | + params = transform._get_params(sample) |
| 1349 | + |
| 1350 | + assert params["needs_crop"] |
| 1351 | + assert params["height"] <= crop_size[0] |
| 1352 | + assert params["width"] <= crop_size[1] |
| 1353 | + |
| 1354 | + assert ( |
| 1355 | + isinstance(params["is_valid"], torch.Tensor) |
| 1356 | + and params["is_valid"].dtype is torch.bool |
| 1357 | + and params["is_valid"].shape == batch_shape |
| 1358 | + ) |
| 1359 | + |
| 1360 | + assert params["needs_pad"] |
| 1361 | + assert any(pad > 0 for pad in params["padding"]) |
| 1362 | + |
| 1363 | + @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) |
| 1364 | + def test__transform(self, mocker, needs): |
| 1365 | + fill_sentinel = mocker.MagicMock() |
| 1366 | + padding_mode_sentinel = mocker.MagicMock() |
| 1367 | + |
| 1368 | + transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) |
| 1369 | + transform._transformed_types = (mocker.MagicMock,) |
| 1370 | + mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) |
| 1371 | + mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) |
| 1372 | + |
| 1373 | + needs_crop, needs_pad = needs |
| 1374 | + top_sentinel = mocker.MagicMock() |
| 1375 | + left_sentinel = mocker.MagicMock() |
| 1376 | + height_sentinel = mocker.MagicMock() |
| 1377 | + width_sentinel = mocker.MagicMock() |
| 1378 | + padding_sentinel = mocker.MagicMock() |
| 1379 | + mocker.patch( |
| 1380 | + "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", |
| 1381 | + return_value=dict( |
| 1382 | + needs_crop=needs_crop, |
| 1383 | + top=top_sentinel, |
| 1384 | + left=left_sentinel, |
| 1385 | + height=height_sentinel, |
| 1386 | + width=width_sentinel, |
| 1387 | + padding=padding_sentinel, |
| 1388 | + needs_pad=needs_pad, |
| 1389 | + ), |
| 1390 | + ) |
| 1391 | + |
| 1392 | + inpt_sentinel = mocker.MagicMock() |
| 1393 | + |
| 1394 | + mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop") |
| 1395 | + mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad") |
| 1396 | + transform(inpt_sentinel) |
| 1397 | + |
| 1398 | + if needs_crop: |
| 1399 | + mock_crop.assert_called_once_with( |
| 1400 | + inpt_sentinel, |
| 1401 | + top=top_sentinel, |
| 1402 | + left=left_sentinel, |
| 1403 | + height=height_sentinel, |
| 1404 | + width=width_sentinel, |
| 1405 | + ) |
| 1406 | + else: |
| 1407 | + mock_crop.assert_not_called() |
| 1408 | + |
| 1409 | + if needs_pad: |
| 1410 | + # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use |
| 1411 | + # `MagicMock.assert_called_once_with` and have to perform the checks manually |
| 1412 | + mock_pad.assert_called_once() |
| 1413 | + args, kwargs = mock_pad.call_args |
| 1414 | + if not needs_crop: |
| 1415 | + assert args[0] is inpt_sentinel |
| 1416 | + assert args[1] is padding_sentinel |
| 1417 | + assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) |
| 1418 | + else: |
| 1419 | + mock_pad.assert_not_called() |
| 1420 | + |
| 1421 | + def test__transform_culling(self, mocker): |
| 1422 | + batch_size = 10 |
| 1423 | + image_size = (10, 10) |
| 1424 | + |
| 1425 | + is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) |
| 1426 | + mocker.patch( |
| 1427 | + "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", |
| 1428 | + return_value=dict( |
| 1429 | + needs_crop=True, |
| 1430 | + top=0, |
| 1431 | + left=0, |
| 1432 | + height=image_size[0], |
| 1433 | + width=image_size[1], |
| 1434 | + is_valid=is_valid, |
| 1435 | + needs_pad=False, |
| 1436 | + ), |
| 1437 | + ) |
| 1438 | + |
| 1439 | + bounding_boxes = make_bounding_box( |
| 1440 | + format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) |
| 1441 | + ) |
| 1442 | + segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,)) |
| 1443 | + labels = make_label(size=(batch_size,)) |
| 1444 | + |
| 1445 | + transform = transforms.FixedSizeCrop((-1, -1)) |
| 1446 | + mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) |
| 1447 | + mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) |
| 1448 | + |
| 1449 | + output = transform( |
| 1450 | + dict( |
| 1451 | + bounding_boxes=bounding_boxes, |
| 1452 | + segmentation_masks=segmentation_masks, |
| 1453 | + labels=labels, |
| 1454 | + ) |
| 1455 | + ) |
| 1456 | + |
| 1457 | + assert_equal(output["bounding_boxes"], bounding_boxes[is_valid]) |
| 1458 | + assert_equal(output["segmentation_masks"], segmentation_masks[is_valid]) |
| 1459 | + assert_equal(output["labels"], labels[is_valid]) |
| 1460 | + |
| 1461 | + def test__transform_bounding_box_clamping(self, mocker): |
| 1462 | + batch_size = 3 |
| 1463 | + image_size = (10, 10) |
| 1464 | + |
| 1465 | + mocker.patch( |
| 1466 | + "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", |
| 1467 | + return_value=dict( |
| 1468 | + needs_crop=True, |
| 1469 | + top=0, |
| 1470 | + left=0, |
| 1471 | + height=image_size[0], |
| 1472 | + width=image_size[1], |
| 1473 | + is_valid=torch.full((batch_size,), fill_value=True), |
| 1474 | + needs_pad=False, |
| 1475 | + ), |
| 1476 | + ) |
| 1477 | + |
| 1478 | + bounding_box = make_bounding_box( |
| 1479 | + format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) |
| 1480 | + ) |
| 1481 | + mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") |
| 1482 | + |
| 1483 | + transform = transforms.FixedSizeCrop((-1, -1)) |
| 1484 | + mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) |
| 1485 | + mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) |
| 1486 | + |
| 1487 | + transform(bounding_box) |
| 1488 | + |
| 1489 | + mock.assert_called_once() |
0 commit comments