Commit 9c3e2bf

port FixedSizeCrop from detection references to prototype transforms (#6417)
* port `FixedSizeCrop` from detection references to prototype transforms
* mypy
* [skip ci] cull invalid boxes and corresponding masks and labels
* cherry-pick missing functions from #6401
* fix feature wrapping
* add test
* mypy
* add input type restrictions
* add test for _get_params
* fix input checks
1 parent 9662001 commit 9c3e2bf
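
The ported transform is used like any other prototype transform: it crops a detection sample down to at most the requested size, pads it back up to exactly that size, and drops annotations whose boxes become degenerate. A minimal usage sketch under assumed inputs (the sample keys, tensor shapes, and box values below are illustrative, not part of this commit):

import torch
from torchvision.prototype import features, transforms

# Hypothetical detection-style sample. FixedSizeCrop.forward() requires an image,
# bounding boxes, and labels (or one-hot labels); segmentation masks are optional.
sample = dict(
    image=features.Image(torch.rand(3, 480, 640)),
    bounding_boxes=features.BoundingBox(
        torch.tensor([[10.0, 20.0, 200.0, 240.0]]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(480, 640),
    ),
    labels=features.Label(torch.tensor([1])),
)

transform = transforms.FixedSizeCrop(size=(512, 512), fill=0, padding_mode="constant")
output = transform(sample)  # cropped to at most 512x512, then padded to exactly 512x512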

File tree

3 files changed: +257 -0 lines changed


test/test_prototype_transforms.py

Lines changed: 159 additions & 0 deletions
@@ -10,6 +10,7 @@
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
+    make_image,
     make_images,
     make_label,
     make_one_hot_labels,
@@ -1328,3 +1329,161 @@ def test__transform(self, mocker):
         transform(inpt_sentinel)

         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+
+
+class TestFixedSizeCrop:
+    def test__get_params(self, mocker):
+        crop_size = (7, 7)
+        batch_shape = (10,)
+        image_size = (11, 5)
+
+        transform = transforms.FixedSizeCrop(size=crop_size)
+
+        sample = dict(
+            image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
+            bounding_boxes=make_bounding_box(
+                format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
+            ),
+        )
+        params = transform._get_params(sample)
+
+        assert params["needs_crop"]
+        assert params["height"] <= crop_size[0]
+        assert params["width"] <= crop_size[1]
+
+        assert (
+            isinstance(params["is_valid"], torch.Tensor)
+            and params["is_valid"].dtype is torch.bool
+            and params["is_valid"].shape == batch_shape
+        )
+
+        assert params["needs_pad"]
+        assert any(pad > 0 for pad in params["padding"])
+
+    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
+    def test__transform(self, mocker, needs):
+        fill_sentinel = mocker.MagicMock()
+        padding_mode_sentinel = mocker.MagicMock()
+
+        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        transform._transformed_types = (mocker.MagicMock,)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        needs_crop, needs_pad = needs
+        top_sentinel = mocker.MagicMock()
+        left_sentinel = mocker.MagicMock()
+        height_sentinel = mocker.MagicMock()
+        width_sentinel = mocker.MagicMock()
+        padding_sentinel = mocker.MagicMock()
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=needs_crop,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+                padding=padding_sentinel,
+                needs_pad=needs_pad,
+            ),
+        )
+
+        inpt_sentinel = mocker.MagicMock()
+
+        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
+        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
+        transform(inpt_sentinel)
+
+        if needs_crop:
+            mock_crop.assert_called_once_with(
+                inpt_sentinel,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+            )
+        else:
+            mock_crop.assert_not_called()
+
+        if needs_pad:
+            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
+            # `MagicMock.assert_called_once_with` and have to perform the checks manually
+            mock_pad.assert_called_once()
+            args, kwargs = mock_pad.call_args
+            if not needs_crop:
+                assert args[0] is inpt_sentinel
+            assert args[1] is padding_sentinel
+            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        else:
+            mock_pad.assert_not_called()
+
+    def test__transform_culling(self, mocker):
+        batch_size = 10
+        image_size = (10, 10)
+
+        is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=is_valid,
+                needs_pad=False,
+            ),
+        )
+
+        bounding_boxes = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
+        labels = make_label(size=(batch_size,))
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        output = transform(
+            dict(
+                bounding_boxes=bounding_boxes,
+                segmentation_masks=segmentation_masks,
+                labels=labels,
+            )
+        )
+
+        assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
+        assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
+        assert_equal(output["labels"], labels[is_valid])
+
+    def test__transform_bounding_box_clamping(self, mocker):
+        batch_size = 3
+        image_size = (10, 10)
+
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=torch.full((batch_size,), fill_value=True),
+                needs_pad=False,
+            ),
+        )
+
+        bounding_box = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        transform(bounding_box)
+
+        mock.assert_called_once()

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
     CenterCrop,
     ElasticTransform,
     FiveCrop,
+    FixedSizeCrop,
     Pad,
     RandomAffine,
     RandomCrop,

torchvision/prototype/transforms/_geometry.py

Lines changed: 97 additions & 0 deletions
@@ -783,3 +783,100 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
+
+
+class FixedSizeCrop(Transform):
+    def __init__(
+        self,
+        size: Union[int, Sequence[int]],
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        padding_mode: str = "constant",
+    ) -> None:
+        super().__init__()
+        size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        needs_crop = new_height != height or new_width != width
+
+        offset_height = max(height - self.crop_height, 0)
+        offset_width = max(width - self.crop_width, 0)
+
+        r = torch.rand(1)
+        top = int(offset_height * r)
+        left = int(offset_width * r)
+
+        if needs_crop:
+            bounding_boxes = query_bounding_box(sample)
+            bounding_boxes = cast(
+                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
+            )
+            bounding_boxes = features.BoundingBox.new_like(
+                bounding_boxes,
+                F.clamp_bounding_box(
+                    bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
+                ),
+            )
+            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            is_valid = torch.all(height_and_width > 0, dim=-1)
+        else:
+            is_valid = None
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+
+        needs_pad = pad_bottom != 0 or pad_right != 0
+
+        return dict(
+            needs_crop=needs_crop,
+            top=top,
+            left=left,
+            height=new_height,
+            width=new_width,
+            is_valid=is_valid,
+            padding=[0, 0, pad_right, pad_bottom],
+            needs_pad=needs_pad,
+        )
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if params["needs_crop"]:
+            inpt = F.crop(
+                inpt,
+                top=params["top"],
+                left=params["left"],
+                height=params["height"],
+                width=params["width"],
+            )
+            if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
+                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
+            elif isinstance(inpt, features.BoundingBox):
+                inpt = features.BoundingBox.new_like(
+                    inpt,
+                    F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
+                )
+
+        if params["needs_pad"]:
+            inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
+
+        return inpt
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+            )
+        return super().forward(sample)
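
The culling behaviour added in `_get_params`/`_transform` boils down to: crop the boxes, clamp them to the crop window, and keep only the entries whose clamped width and height stay positive; the resulting boolean mask then indexes boxes, masks, and labels alike. A standalone sketch of that validity check in plain torch (the box and label values are made up for illustration):

import torch

# XYXY boxes after a hypothetical crop, already clamped to the crop window
boxes = torch.tensor(
    [
        [0.0, 0.0, 50.0, 60.0],    # positive width and height -> kept
        [30.0, 40.0, 30.0, 90.0],  # zero width -> culled
    ]
)

# same check as in _get_params: the XYWH width/height must both be > 0
width_height = boxes[..., 2:] - boxes[..., :2]
is_valid = torch.all(width_height > 0, dim=-1)  # tensor([ True, False])

# corresponding masks and labels are filtered with the same mask
labels = torch.tensor([3, 7])
print(labels[is_valid])  # tensor([3])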
