From 7b10e07c2eed2e510dece798443328b15592052e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 15 Aug 2022 05:22:11 +0200
Subject: [PATCH 01/10] port `FixedSizeCrop` from detection references to
 prototype transforms

---
 torchvision/prototype/transforms/__init__.py  |  1 +
 torchvision/prototype/transforms/_geometry.py | 57 +++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 5617c010e5f..dd3696e8dbe 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -20,6 +20,7 @@
     CenterCrop,
     ElasticTransform,
     FiveCrop,
+    FixedSizeCrop,
     Pad,
     RandomAffine,
     RandomCrop,
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index aa1ca109cc4..4b2e8a48fa7 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -611,3 +611,60 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             fill=self.fill,
             interpolation=self.interpolation,
         )
+
+
+class FixedSizeCrop(Transform):
+    def __init__(self, size, fill=0, padding_mode="constant"):
+        super().__init__()
+        size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        needs_crop = new_height != height or new_width != width
+
+        offset_height = max(height - self.crop_height, 0)
+        offset_width = max(width - self.crop_width, 0)
+
+        r = torch.rand(1)
+        top = int(offset_height * r)
+        left = int(offset_width * r)
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+
+        needs_pad = pad_bottom != 0 or pad_right != 0
+
+        return dict(
+            needs_crop=needs_crop,
+            top=top,
+            left=left,
+            height=new_height,
+            width=new_width,
+            padding=[0, 0, pad_right, pad_bottom],
+            needs_pad=needs_pad,
+        )
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if params["needs_crop"]:
+            inpt = F.crop(
+                inpt,
+                top=params["top"],
+                left=params["left"],
+                height=params["height"],
+                width=params["width"],
+            )
+        # TODO: cull invalid bounding boxes and labels after we have resolved the preferred way in
+        # https://github.com/pytorch/vision/pull/6401
+
+        if params["needs_pad"]:
+            inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
+
+        return inpt

From 279502a5c578ceefd12c651dc03853512fee8f2c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 15 Aug 2022 05:56:15 +0200
Subject: [PATCH 02/10] mypy

---
 torchvision/prototype/transforms/_geometry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 4b2e8a48fa7..f9cf358503c 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -614,7 +614,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class FixedSizeCrop(Transform):
-    def __init__(self, size, fill=0, padding_mode="constant"):
+    def __init__(self, size: Union[int, Sequence[int]], fill=0, padding_mode="constant") -> None:
         super().__init__()
         size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
         self.crop_height = size[0]
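At this stage the transform only crops and pads; bounding boxes and labels are handled starting with the next patch, and forward() only gains input checks in patch 08, so a bare image is accepted. A minimal sketch of the behavior after patches 01-02 (illustrative, assuming the prototype features/transforms namespaces on this branch):

    import torch
    from torchvision.prototype import features, transforms

    # A 3x20x30 image is smaller than the 32x32 target on both sides, so
    # _get_params() reports needs_crop=False and needs_pad=True, and F.pad
    # fills the bottom/right with `fill` to reach the fixed size.
    image = features.Image(torch.rand(3, 20, 30))
    transform = transforms.FixedSizeCrop(size=32)  # fill=0, padding_mode="constant"
    output = transform(image)
    assert tuple(output.shape[-2:]) == (32, 32)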
From 105bccd325d67509f2267f2fbeb617549968d516 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 16 Aug 2022 15:32:28 +0200
Subject: [PATCH 03/10] [skip ci] cull invalid boxes and corresponding masks
 and labels

---
 torchvision/prototype/transforms/_geometry.py | 26 ++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 38aa5c3ae63..b6c9f7b3f8f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -13,7 +13,7 @@
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bboxes, query_image
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -698,6 +698,20 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         top = int(offset_height * r)
         left = int(offset_width * r)
 
+        if needs_crop:
+            bounding_boxes = query_bboxes(sample)
+            bounding_boxes = F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
+            bounding_boxes = features.BoundingBox.new_like(
+                bounding_boxes,
+                F.clamp_bounding_box(
+                    bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
+                ),
+            )
+            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            is_valid = torch.all(height_and_width > 0, dim=-1)
+        else:
+            is_valid = None
+
         pad_bottom = max(self.crop_height - new_height, 0)
         pad_right = max(self.crop_width - new_width, 0)
 
@@ -709,6 +723,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
             left=left,
             height=new_height,
             width=new_width,
+            is_valid=is_valid,
             padding=[0, 0, pad_right, pad_bottom],
             needs_pad=needs_pad,
         )
@@ -722,8 +737,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 height=params["height"],
                 width=params["width"],
             )
-        # TODO: cull invalid bounding boxes and labels after we have resolved the preferred way in
-        # https://github.com/pytorch/vision/pull/6401
+            if isinstance(inpt, (features.BoundingBox, features.Label, features.SegmentationMask)):
+                inpt = inpt[params["is_valid"]]
+            if isinstance(inpt, features.BoundingBox):
+                features.BoundingBox.new_like(
+                    inpt,
+                    F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size),
+                )
 
         if params["needs_pad"]:
             inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
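To see why the XYWH check identifies the boxes to cull, consider an illustrative case (numbers not from the patch): with a 10x10 crop window at top=0, left=0, an XYXY box [12, 3, 18, 7] lies entirely to the right of the window. Clamping it to the cropped image gives [10, 3, 10, 7], whose XYWH width is 0, so torch.all(height_and_width > 0, dim=-1) marks it invalid and _transform drops it together with its label and mask.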
From f5a4b2d6e950065c88d8f08bef5b2a1346a3976f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 17 Aug 2022 14:57:18 +0200
Subject: [PATCH 04/10] cherry-pick missing functions from #6401

---
 torchvision/prototype/transforms/_geometry.py           | 4 ++--
 torchvision/prototype/transforms/_utils.py              | 9 +++++++++
 torchvision/prototype/transforms/functional/__init__.py | 1 +
 torchvision/prototype/transforms/functional/_meta.py    | 9 +++++++++
 4 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index b6c9f7b3f8f..61beea718f9 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -13,7 +13,7 @@
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bboxes, query_image
+from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bounding_box, query_image
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -699,7 +699,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         left = int(offset_width * r)
 
         if needs_crop:
-            bounding_boxes = query_bboxes(sample)
+            bounding_boxes = query_bounding_box(sample)
             bounding_boxes = F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
             bounding_boxes = features.BoundingBox.new_like(
                 bounding_boxes,
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index 9f2ef84ced5..4cfe1da3649 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
     raise TypeError("No image was found in the sample")
 
 
+def query_bounding_box(sample: Any) -> features.BoundingBox:
+    flat_sample, _ = tree_flatten(sample)
+    for i in flat_sample:
+        if isinstance(i, features.BoundingBox):
+            return i
+
+    raise TypeError("No bounding box was found in the sample")
+
+
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index fee0c4dd1e3..5883cc9119a 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
+    clamp_bounding_box,
     convert_bounding_box_format,
     convert_color_space_image_tensor,
     convert_color_space_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index f1aea2018bc..168a6dfe1b4 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
     return bounding_box
 
 
+def clamp_bounding_box(
+    bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
+    return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+
+
 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return image[..., :-1, :, :], image[..., -1:, :, :]
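For intuition about the cherry-picked helper: clamp_bounding_box converts to XYXY, clips every x coordinate to [0, width] and every y coordinate to [0, height], and converts back. A tiny worked example (illustrative values; note that image_size is (height, width)):

    import torch
    from torchvision.prototype.features import BoundingBoxFormat
    from torchvision.prototype.transforms.functional import clamp_bounding_box

    boxes = torch.tensor([[-3.0, 2.0, 12.0, 15.0]])  # XYXY, hanging over a 10x10 image
    clamped = clamp_bounding_box(boxes, format=BoundingBoxFormat.XYXY, image_size=(10, 10))
    # -> tensor([[ 0.,  2., 10., 10.]])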
From c9e36e829ec89a215d2d274ffb4d2bcb5a859b0f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 17 Aug 2022 15:15:54 +0200
Subject: [PATCH 05/10] fix feature wrapping

---
 torchvision/prototype/transforms/_geometry.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 61beea718f9..f0344a51cb8 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -737,12 +737,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 height=params["height"],
                 width=params["width"],
             )
-            if isinstance(inpt, (features.BoundingBox, features.Label, features.SegmentationMask)):
-                inpt = inpt[params["is_valid"]]
-            if isinstance(inpt, features.BoundingBox):
-                features.BoundingBox.new_like(
+            if isinstance(inpt, (features.Label, features.SegmentationMask)):
+                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])
+            elif isinstance(inpt, features.BoundingBox):
+                inpt = features.BoundingBox.new_like(
                     inpt,
-                    F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size),
+                    F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
                 )
 
         if params["needs_pad"]:

From 31127190f2066ae1cec5abb8b3b5d81163f0c828 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 17 Aug 2022 16:20:59 +0200
Subject: [PATCH 06/10] add test

---
 test/test_prototype_transforms.py | 125 ++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 30c68e118e1..fb21b3400b0 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1209,3 +1209,128 @@ def test__transform(self, mocker):
         transform(inpt_sentinel)
 
         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+
+
+class TestFixedSizeCrop:
+    def test__get_params(self):
+        pass
+
+    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
+    def test__transform(self, mocker, needs):
+        fill_sentinel = mocker.MagicMock()
+        padding_mode_sentinel = mocker.MagicMock()
+
+        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        transform._transformed_types = (mocker.MagicMock,)
+
+        needs_crop, needs_pad = needs
+        top_sentinel = mocker.MagicMock()
+        left_sentinel = mocker.MagicMock()
+        height_sentinel = mocker.MagicMock()
+        width_sentinel = mocker.MagicMock()
+        padding_sentinel = mocker.MagicMock()
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=needs_crop,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+                padding=padding_sentinel,
+                needs_pad=needs_pad,
+            ),
+        )
+
+        inpt_sentinel = mocker.MagicMock()
+
+        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
+        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
+        transform(inpt_sentinel)
+
+        if needs_crop:
+            mock_crop.assert_called_once_with(
+                inpt_sentinel,
+                top=top_sentinel,
+                left=left_sentinel,
+                height=height_sentinel,
+                width=width_sentinel,
+            )
+        else:
+            mock_crop.assert_not_called()
+
+        if needs_pad:
+            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
+            # `MagicMock.assert_called_once_with` and have to perform the checks manually
+            mock_pad.assert_called_once()
+            args, kwargs = mock_pad.call_args
+            if not needs_crop:
+                assert args[0] is inpt_sentinel
+            assert args[1] is padding_sentinel
+            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        else:
+            mock_pad.assert_not_called()
+
+    def test__transform_culling(self, mocker):
+        batch_size = 10
+        image_size = (10, 10)
+
+        is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=is_valid,
+                needs_pad=False,
+            ),
+        )
+
+        bounding_boxes = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
+        labels = make_label(size=(batch_size,))
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        output = transform(
+            dict(
+                bounding_boxes=bounding_boxes,
+                segmentation_masks=segmentation_masks,
+                labels=labels,
+            )
+        )
+
+        assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
+        assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
+        assert_equal(output["labels"], labels[is_valid])
+
+    def test__transform_bounding_box_clamping(self, mocker):
+        batch_size = 3
+        image_size = (10, 10)
+
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            return_value=dict(
+                needs_crop=True,
+                top=0,
+                left=0,
+                height=image_size[0],
+                width=image_size[1],
+                is_valid=torch.full((batch_size,), fill_value=True),
+                needs_pad=False,
+            ),
+        )
+
+        bounding_box = make_bounding_box(
+            format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
+        )
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
+
+        transform = transforms.FixedSizeCrop((-1, -1))
+        transform(bounding_box)
+
+        mock.assert_called_once()
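Assuming a development install of torchvision, the new tests can be run in isolation with:

    pytest test/test_prototype_transforms.py -k TestFixedSizeCrop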
From d5d456a909f5d4f9fa1a872448aa0d72bd85e6e5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 17 Aug 2022 16:31:48 +0200
Subject: [PATCH 07/10] mypy

---
 torchvision/prototype/transforms/_geometry.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index f0344a51cb8..3ff273ccf02 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -675,7 +675,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class FixedSizeCrop(Transform):
-    def __init__(self, size: Union[int, Sequence[int]], fill=0, padding_mode="constant") -> None:
+    def __init__(
+        self,
+        size: Union[int, Sequence[int]],
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        padding_mode: str = "constant",
+    ) -> None:
         super().__init__()
         size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
         self.crop_height = size[0]
@@ -700,7 +705,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
 
         if needs_crop:
             bounding_boxes = query_bounding_box(sample)
-            bounding_boxes = F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
+            bounding_boxes = cast(
+                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
+            )
             bounding_boxes = features.BoundingBox.new_like(
                 bounding_boxes,
                 F.clamp_bounding_box(
@@ -745,7 +752,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 width=params["width"],
             )
             if isinstance(inpt, (features.Label, features.SegmentationMask)):
-                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])
+                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
             elif isinstance(inpt, features.BoundingBox):
                 inpt = features.BoundingBox.new_like(
                     inpt,
From 24d584fc4db874a730b4c8ca665f0bc7796a640a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 18 Aug 2022 09:35:21 +0200
Subject: [PATCH 08/10] add input type restrictions

---
 test/test_prototype_transforms.py             |  8 ++++++++
 torchvision/prototype/transforms/_geometry.py | 18 ++++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index fb21b3400b0..fc2af533666 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1222,6 +1222,8 @@ def test__transform(self, mocker, needs):
 
         transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
         transform._transformed_types = (mocker.MagicMock,)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
 
         needs_crop, needs_pad = needs
         top_sentinel = mocker.MagicMock()
@@ -1296,6 +1298,9 @@ def test__transform_culling(self, mocker):
         labels = make_label(size=(batch_size,))
 
         transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
         output = transform(
             dict(
                 bounding_boxes=bounding_boxes,
@@ -1331,6 +1336,9 @@ def test__transform_bounding_box_clamping(self, mocker):
         mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
 
         transform = transforms.FixedSizeCrop((-1, -1))
+        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
+        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
+
         transform(bounding_box)
 
         mock.assert_called_once()
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 3ff273ccf02..53e0625577e 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -13,7 +13,7 @@
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bounding_box, query_image
+from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -744,7 +744,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 height=params["height"],
                 width=params["width"],
             )
-            if isinstance(inpt, (features.Label, features.SegmentationMask)):
+            if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
                 inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
             elif isinstance(inpt, features.BoundingBox):
                 inpt = features.BoundingBox.new_like(
@@ -756,3 +756,17 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
 
         return inpt
+
+    def forward(self, *inputs: Any) -> Any:
+        # FIXME: revisit after https://github.com/pytorch/vision/pull/6401#discussion_r948749012 is resolved
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+            )
+        return super().forward(sample)
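With these checks in place, a sample that cannot be culled consistently is rejected up front rather than failing later. A hypothetical call that now raises (sketch, not from the patch):

    transform = transforms.FixedSizeCrop((32, 32))
    transform(features.Image(torch.rand(3, 48, 48)))  # no bounding boxes or labels in the sample
    # TypeError: FixedSizeCrop() requires input sample to contain Images or PIL Images, ...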
From 8a11b7eba88092f83d328f97a09874e5b26fe3f7 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 18 Aug 2022 09:50:32 +0200
Subject: [PATCH 09/10] add test for _get_params

---
 test/test_prototype_transforms.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index fc2af533666..11a7580c682 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -10,6 +10,7 @@
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
+    make_image,
     make_images,
     make_label,
     make_one_hot_labels,
@@ -1212,8 +1213,33 @@
 
 
 class TestFixedSizeCrop:
-    def test__get_params(self):
-        pass
+    def test__get_params(self, mocker):
+        crop_size = (7, 7)
+        batch_shape = (10,)
+        image_size = (11, 5)
+
+        transform = transforms.FixedSizeCrop(size=crop_size)
+
+        sample = dict(
+            image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
+            bounding_boxes=make_bounding_box(
+                format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
+            ),
+        )
+        params = transform._get_params(sample)
+
+        assert params["needs_crop"]
+        assert params["height"] <= crop_size[0]
+        assert params["width"] <= crop_size[1]
+
+        assert (
+            isinstance(params["is_valid"], torch.Tensor)
+            and params["is_valid"].dtype is torch.bool
+            and params["is_valid"].shape == batch_shape
+        )
+
+        assert params["needs_pad"]
+        assert any(pad > 0 for pad in params["padding"])
 
     @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
     def test__transform(self, mocker, needs):

From 435d5b5911ca49a3d87e3642c9d54bbb5b0336c9 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 18 Aug 2022 17:00:14 +0200
Subject: [PATCH 10/10] fix input checks

---
 torchvision/prototype/transforms/_geometry.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 6bb36d97821..ed8e4cc87ea 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -869,11 +869,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
     def forward(self, *inputs: Any) -> Any:
-        # FIXME: revisit after https://github.com/pytorch/vision/pull/6401#discussion_r948749012 is resolved
         sample = inputs if len(inputs) > 1 else inputs[0]
         if not (
             has_all(sample, features.BoundingBox)
-            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
             and has_any(sample, features.Label, features.OneHotLabel)
         ):
             raise TypeError(
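Taken together, the series leaves FixedSizeCrop cropping a random fixed-size window, padding when the image is smaller, culling boxes (and their labels and masks) that collapse inside the window, and validating the sample up front. An end-to-end sketch (illustrative values, assuming the prototype features API at the time of this series):

    import torch
    from torchvision.prototype import features, transforms

    image_size = (48, 64)
    sample = dict(
        image=features.Image(torch.rand(3, *image_size)),
        boxes=features.BoundingBox(
            torch.tensor([[0.0, 0.0, 10.0, 10.0], [40.0, 40.0, 60.0, 46.0]]),
            format=features.BoundingBoxFormat.XYXY,
            image_size=image_size,
        ),
        labels=features.Label(torch.tensor([1, 2])),
    )

    transform = transforms.FixedSizeCrop((32, 32))
    output = transform(sample)
    assert tuple(output["image"].shape[-2:]) == (32, 32)
    # Boxes and labels are culled by the same is_valid mask, so they stay aligned.
    assert len(output["boxes"]) == len(output["labels"])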