diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 8839d842b85..6afc7b95b7e 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -10,6 +10,7 @@
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
+    make_image,
     make_images,
     make_label,
     make_one_hot_labels,
@@ -1328,3 +1329,161 @@ def test__transform(self, mocker):
         transform(inpt_sentinel)
 
         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+
+
+class TestFixedSizeCrop:
+    def test__get_params(self, mocker):
+        crop_size = (7, 7)
+        batch_shape = (10,)
+        image_size = (11, 5)
+
+        transform = transforms.FixedSizeCrop(size=crop_size)
+
+        sample = dict(
+            image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
+            bounding_boxes=make_bounding_box(
+                format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
+            ),
+        )
+        params = transform._get_params(sample)
+
+        assert params["needs_crop"]
+        assert params["height"] <= crop_size[0]
+        assert params["width"] <= crop_size[1]
+
+        assert (
+            isinstance(params["is_valid"], torch.Tensor)
+            and params["is_valid"].dtype is torch.bool
+            and params["is_valid"].shape == batch_shape
+        )
+
+        assert params["needs_pad"]
+        assert any(pad > 0 for pad in params["padding"])
+
+    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
+    def test__transform(self, mocker, needs):
+        fill_sentinel = mocker.MagicMock()
+        padding_mode_sentinel = mocker.MagicMock()
+
+        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        transform._transformed_types = (mocker.MagicMock,)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        needs_crop, needs_pad = needs
+        top_sentinel = mocker.MagicMock()
+        left_sentinel = mocker.MagicMock()
+        height_sentinel = mocker.MagicMock()
+        width_sentinel = mocker.MagicMock()
+        padding_sentinel = mocker.MagicMock()
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=needs_crop,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+                padding=padding_sentinel,
+                needs_pad=needs_pad,
+            ),
+        )
+
+        inpt_sentinel = mocker.MagicMock()
+
+        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
+        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
+        transform(inpt_sentinel)
+
+        if needs_crop:
+            mock_crop.assert_called_once_with(
+                inpt_sentinel,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+            )
+        else:
+            mock_crop.assert_not_called()
+
+        if needs_pad:
+            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
+            # `MagicMock.assert_called_once_with` and have to perform the checks manually
+            mock_pad.assert_called_once()
+            args, kwargs = mock_pad.call_args
+            if not needs_crop:
+                assert args[0] is inpt_sentinel
+            assert args[1] is padding_sentinel
+            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        else:
+            mock_pad.assert_not_called()
+
+    def test__transform_culling(self, mocker):
+        batch_size = 10
+        image_size = (10, 10)
+
+        is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=is_valid,
+                needs_pad=False,
+            ),
+        )
+
+        bounding_boxes = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
+        labels = make_label(size=(batch_size,))
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        output = transform(
+            dict(
+                bounding_boxes=bounding_boxes,
+                segmentation_masks=segmentation_masks,
+                labels=labels,
+            )
+        )
+
+        assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
+        assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
+        assert_equal(output["labels"], labels[is_valid])
+
+    def test__transform_bounding_box_clamping(self, mocker):
+        batch_size = 3
+        image_size = (10, 10)
+
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=torch.full((batch_size,), fill_value=True),
+                needs_pad=False,
+            ),
+        )
+
+        bounding_box = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
+        transform(bounding_box)
+
+        mock.assert_called_once()
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index e1ba20904fe..aa920aa2ef2 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -20,6 +20,7 @@
     CenterCrop,
     ElasticTransform,
     FiveCrop,
+    FixedSizeCrop,
     Pad,
     RandomAffine,
     RandomCrop,
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 32f220f2f9f..ed8e4cc87ea 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -783,3 +783,100 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
+
+
+class FixedSizeCrop(Transform):
+    def __init__(
+        self,
+        size: Union[int, Sequence[int]],
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        padding_mode: str = "constant",
+    ) -> None:
+        super().__init__()
+        size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        needs_crop = new_height != height or new_width != width
+
+        offset_height = max(height - self.crop_height, 0)
+        offset_width = max(width - self.crop_width, 0)
+
+        r = torch.rand(1)
+        top = int(offset_height * r)
+        left = int(offset_width * r)
+
+        if needs_crop:
+            bounding_boxes = query_bounding_box(sample)
+            bounding_boxes = cast(
+                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
+            )
+            bounding_boxes = features.BoundingBox.new_like(
+                bounding_boxes,
+                F.clamp_bounding_box(
+                    bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
+                ),
+            )
+            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            is_valid = torch.all(height_and_width > 0, dim=-1)
+        else:
+            is_valid = None
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+
+        needs_pad = pad_bottom != 0 or pad_right != 0
+
+        return dict(
+            needs_crop=needs_crop,
+            top=top,
+            left=left,
+            height=new_height,
+            width=new_width,
+            is_valid=is_valid,
+            padding=[0, 0, pad_right, pad_bottom],
+            needs_pad=needs_pad,
+        )
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if params["needs_crop"]:
+            inpt = F.crop(
+                inpt,
+                top=params["top"],
+                left=params["left"],
+                height=params["height"],
+                width=params["width"],
+            )
+            if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
+                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
+            elif isinstance(inpt, features.BoundingBox):
+                inpt = features.BoundingBox.new_like(
+                    inpt,
+                    F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
+                )
+
+        if params["needs_pad"]:
+            inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
+
+        return inpt
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+            )
+        return super().forward(sample)
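
For context, a minimal usage sketch of the new transform (not part of the diff). The sample values and the exact constructor signatures of the prototype wrappers (`features.Image`, `features.BoundingBox`, `features.Label`) are assumptions based on the prototype API at the time of this PR and may differ in other revisions; `FixedSizeCrop.forward` requires the sample to contain an image, bounding boxes, and labels.

```python
import torch
from torchvision.prototype import features, transforms

# Hypothetical sample: one 3x480x640 image, two XYXY boxes, two labels.
sample = dict(
    image=features.Image(torch.rand(3, 480, 640)),
    bounding_boxes=features.BoundingBox(
        torch.tensor([[10.0, 20.0, 110.0, 220.0], [300.0, 50.0, 400.0, 150.0]]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(480, 640),
    ),
    labels=features.Label(torch.tensor([1, 2])),
)

transform = transforms.FixedSizeCrop(size=(512, 512), fill=0, padding_mode="constant")
output = transform(sample)
# The image is cropped to at most 512x512, boxes that end up with zero area are
# dropped (together with their labels), the remaining boxes are clamped, and
# bottom/right padding brings the output up to exactly 512x512.
```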