From df6918cc653220fc465abe8aa4b5622e70c60de2 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 12 Jul 2022 22:50:09 +0200 Subject: [PATCH 01/13] [proto] Added few transforms tests, part 1 (#6262) * Added supported/unsupported data checks in the tests for cutmix/mixup * Added RandomRotation, RandomAffine transforms tests * Added tests for RandomZoomOut, Pad * Update test_prototype_transforms.py --- test/test_prototype_transforms.py | 265 +++++++++++++++++- torchvision/prototype/transforms/_geometry.py | 14 +- torchvision/transforms/transforms.py | 2 +- 3 files changed, 273 insertions(+), 8 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index eb92af41071..899835ba276 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -6,10 +6,13 @@ from test_prototype_transforms_functional import ( make_images, make_bounding_boxes, + make_bounding_box, make_one_hot_labels, + make_label, + make_segmentation_mask, ) from torchvision.prototype import transforms, features -from torchvision.transforms.functional import to_pil_image, pil_to_tensor +from torchvision.transforms.functional import to_pil_image, pil_to_tensor, InterpolationMode def make_vanilla_tensor_images(*args, **kwargs): @@ -106,6 +109,20 @@ def test_common(self, transform, input): def test_mixup_cutmix(self, transform, input): transform(input) + # add other data that should bypass and wont raise any error + input_copy = dict(input) + input_copy["path"] = "/path/to/somewhere" + input_copy["num"] = 1234 + transform(input_copy) + + # Check if we raise an error if sample contains bbox or mask or label + err_msg = "does not support bounding boxes, segmentation masks and plain labels" + input_copy = dict(input) + for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]: + input_copy["unsupported"] = unsup_data + with pytest.raises(TypeError, match=err_msg): + transform(input_copy) + @parametrize( [ ( @@ -303,3 +320,249 @@ def test_features_bounding_box(self, p): assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size + + +class TestPad: + def test_assertions(self): + with pytest.raises(TypeError, match="Got inappropriate padding arg"): + transforms.Pad("abc") + + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + transforms.Pad([-0.7, 0, 0.7]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.Pad(12, fill="abc") + + with pytest.raises(ValueError, match="Padding mode should be either"): + transforms.Pad(12, padding_mode="abc") + + @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]]) + @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) + @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) + def test__transform(self, padding, fill, padding_mode, mocker): + transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) + + fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + inpt = mocker.MagicMock(spec=torch.Tensor) + _ = transform(inpt) + + fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + + +class TestRandomZoomOut: + def test_assertions(self): + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomZoomOut(fill="abc") + + with pytest.raises(TypeError, match="should be a sequence of length"): + transforms.RandomZoomOut(0, side_range=0) + + with pytest.raises(ValueError, 
match="Invalid canvas side range"): + transforms.RandomZoomOut(0, side_range=[4.0, 1.0]) + + @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) + @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) + def test__get_params(self, fill, side_range): + transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) + + image = features.Image(torch.rand(1, 3, 32, 32)) + c, h, w = image.shape[-3:] + + params = transform._get_params(image) + + assert params["fill"] == (fill if not isinstance(fill, int) else [fill] * c) + assert len(params["padding"]) == 4 + assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w + assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h + assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w + assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h + + @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) + @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) + def test__transform(self, fill, side_range, mocker): + image = features.Image(torch.rand(1, 3, 32, 32)) + transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) + + fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = transform(image) + torch.manual_seed(12) + torch.rand(1) # random apply changes random state + params = transform._get_params(image) + + fn.assert_called_once_with(image, **params) + + +class TestRandomRotation: + def test_assertions(self): + with pytest.raises(ValueError, match="is a single number, it must be positive"): + transforms.RandomRotation(-0.7) + + for d in [[-0.7], [-0.7, 0, 0.7]]: + with pytest.raises(ValueError, match="degrees should be a sequence of length 2"): + transforms.RandomRotation(d) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomRotation(12, fill="abc") + + with pytest.raises(TypeError, match="center should be a sequence of length"): + transforms.RandomRotation(12, center=12) + + with pytest.raises(ValueError, match="center should be a sequence of length"): + transforms.RandomRotation(12, center=[1, 2, 3]) + + def test__get_params(self): + angle_bound = 34 + transform = transforms.RandomRotation(angle_bound) + + params = transform._get_params(None) + assert -angle_bound <= params["angle"] <= angle_bound + + angle_bounds = [12, 34] + transform = transforms.RandomRotation(angle_bounds) + + params = transform._get_params(None) + assert angle_bounds[0] <= params["angle"] <= angle_bounds[1] + + @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) + @pytest.mark.parametrize("expand", [False, True]) + @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) + @pytest.mark.parametrize("center", [None, [2.0, 3.0]]) + def test__transform(self, degrees, expand, fill, center, mocker): + interpolation = InterpolationMode.BILINEAR + transform = transforms.RandomRotation( + degrees, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + + if isinstance(degrees, (tuple, list)): + assert transform.degrees == [float(degrees[0]), float(degrees[1])] + else: + assert transform.degrees == [float(-degrees), float(degrees)] + + fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") + inpt = mocker.MagicMock(spec=torch.Tensor) + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = 
transform(inpt) + torch.manual_seed(12) + params = transform._get_params(inpt) + + fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +class TestRandomAffine: + def test_assertions(self): + with pytest.raises(ValueError, match="is a single number, it must be positive"): + transforms.RandomAffine(-0.7) + + for d in [[-0.7], [-0.7, 0, 0.7]]: + with pytest.raises(ValueError, match="degrees should be a sequence of length 2"): + transforms.RandomAffine(d) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomAffine(12, fill="abc") + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomAffine(12, fill="abc") + + for kwargs in [ + {"center": 12}, + {"translate": 12}, + {"scale": 12}, + ]: + with pytest.raises(TypeError, match="should be a sequence of length"): + transforms.RandomAffine(12, **kwargs) + + for kwargs in [{"center": [1, 2, 3]}, {"translate": [1, 2, 3]}, {"scale": [1, 2, 3]}]: + with pytest.raises(ValueError, match="should be a sequence of length"): + transforms.RandomAffine(12, **kwargs) + + with pytest.raises(ValueError, match="translation values should be between 0 and 1"): + transforms.RandomAffine(12, translate=[-1.0, 2.0]) + + with pytest.raises(ValueError, match="scale values should be positive"): + transforms.RandomAffine(12, scale=[-1.0, 2.0]) + + with pytest.raises(ValueError, match="is a single number, it must be positive"): + transforms.RandomAffine(12, shear=-10) + + for s in [[-0.7], [-0.7, 0, 0.7]]: + with pytest.raises(ValueError, match="shear should be a sequence of length 2"): + transforms.RandomAffine(12, shear=s) + + @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) + @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) + @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) + @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) + def test__get_params(self, degrees, translate, scale, shear): + image = features.Image(torch.rand(1, 3, 32, 32)) + h, w = image.shape[-2:] + + transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) + params = transform._get_params(image) + + if not isinstance(degrees, (list, tuple)): + assert -degrees <= params["angle"] <= degrees + else: + assert degrees[0] <= params["angle"] <= degrees[1] + + if translate is not None: + assert -translate[0] * w <= params["translations"][0] <= translate[0] * w + assert -translate[1] * h <= params["translations"][1] <= translate[1] * h + else: + assert params["translations"] == (0, 0) + + if scale is not None: + assert scale[0] <= params["scale"] <= scale[1] + else: + assert params["scale"] == 1.0 + + if shear is not None: + if isinstance(shear, float): + assert -shear <= params["shear"][0] <= shear + assert params["shear"][1] == 0.0 + elif len(shear) == 2: + assert shear[0] <= params["shear"][0] <= shear[1] + assert params["shear"][1] == 0.0 + else: + assert shear[0] <= params["shear"][0] <= shear[1] + assert shear[2] <= params["shear"][1] <= shear[3] + else: + assert params["shear"] == (0, 0) + + @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) + @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) + @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) + @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) + @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) + @pytest.mark.parametrize("center", [None, [2.0, 3.0]]) + def 
test__transform(self, degrees, translate, scale, shear, fill, center, mocker): + interpolation = InterpolationMode.BILINEAR + transform = transforms.RandomAffine( + degrees, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + + if isinstance(degrees, (tuple, list)): + assert transform.degrees == [float(degrees[0]), float(degrees[1])] + else: + assert transform.degrees == [float(-degrees), float(degrees)] + + fn = mocker.patch("torchvision.prototype.transforms.functional.affine") + inpt = features.Image(torch.rand(1, 3, 32, 32)) + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = transform(inpt) + torch.manual_seed(12) + params = transform._get_params(inpt) + + fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index d4162b2b631..fd14ac0296b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -236,16 +236,16 @@ def __init__( if not isinstance(padding, (numbers.Number, tuple, list)): raise TypeError("Got inappropriate padding arg") + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + _check_fill_arg(fill) if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError( - f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" - ) - self.padding = padding self.fill = fill self.padding_mode = padding_mode @@ -258,7 +258,7 @@ class RandomZoomOut(_RandomApplyTransform): def __init__( self, fill: Union[int, float, Sequence[int], Sequence[float]] = 0, - side_range: Tuple[float, float] = (1.0, 4.0), + side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: super().__init__(p=p) @@ -266,6 +266,8 @@ def __init__( _check_fill_arg(fill) self.fill = fill + _check_sequence_input(side_range, "side_range", req_sizes=(2,)) + self.side_range = side_range if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid canvas side range provided {side_range}.") diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 56f69e82033..cf119759982 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1855,7 +1855,7 @@ def _check_sequence_input(x, name, req_sizes): if not isinstance(x, Sequence): raise TypeError(f"{name} should be a sequence of length {msg}.") if len(x) not in req_sizes: - raise ValueError(f"{name} should be sequence of length {msg}.") + raise ValueError(f"{name} should be a sequence of length {msg}.") def _setup_angle(x, name, req_sizes=(2,)): From 615b175e346c6a29382cce562762d6dc661f5c95 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 14 Jul 2022 16:30:36 +0200 Subject: [PATCH 02/13] Added RandomCrop transform and tests (#6271) --- test/test_prototype_transforms.py | 78 ++++++++++++++ torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/_geometry.py | 101 +++++++++++++++--- torchvision/prototype/transforms/_utils.py | 13 ++- 4 files 
changed, 180 insertions(+), 13 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 899835ba276..d561705fdfe 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -82,6 +82,7 @@ class TestSmoke: transforms.RandomZoomOut(), transforms.RandomRotation(degrees=(-45, 45)), transforms.RandomAffine(degrees=(-45, 45)), + transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True), ) def test_common(self, transform, input): transform(input) @@ -566,3 +567,80 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) + + +class TestRandomCrop: + def test_assertions(self): + with pytest.raises(ValueError, match="Please provide only two dimensions"): + transforms.RandomCrop([10, 12, 14]) + + with pytest.raises(TypeError, match="Got inappropriate padding arg"): + transforms.RandomCrop([10, 12], padding="abc") + + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomCrop([10, 12], padding=1, fill="abc") + + with pytest.raises(ValueError, match="Padding mode should be either"): + transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") + + def test__get_params(self): + image = features.Image(torch.rand(1, 3, 32, 32)) + h, w = image.shape[-2:] + + transform = transforms.RandomCrop([10, 10]) + params = transform._get_params(image) + + assert 0 <= params["top"] <= h - transform.size[0] + 1 + assert 0 <= params["left"] <= w - transform.size[1] + 1 + assert params["height"] == 10 + assert params["width"] == 10 + + @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) + @pytest.mark.parametrize("pad_if_needed", [False, True]) + @pytest.mark.parametrize("fill", [False, True]) + @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) + def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): + output_size = [10, 12] + transform = transforms.RandomCrop( + output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode + ) + + inpt = features.Image(torch.rand(1, 3, 32, 32)) + expected = mocker.MagicMock(spec=features.Image) + expected.num_channels = 3 + if isinstance(padding, int): + expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding) + elif isinstance(padding, list): + expected.image_size = ( + inpt.image_size[0] + sum(padding[0::2]), + inpt.image_size[1] + sum(padding[1::2]), + ) + else: + expected.image_size = inpt.image_size + _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected) + fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop") + + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = transform(inpt) + torch.manual_seed(12) + if padding is None and not pad_if_needed: + params = transform._get_params(inpt) + fn_crop.assert_called_once_with( + inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] + ) + elif not pad_if_needed: + params = transform._get_params(expected) + fn_crop.assert_called_once_with( + expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] + ) + 
elif padding is None: + # vfdev-5: I do not know how to mock and test this case + pass + else: + # vfdev-5: I do not know how to mock and test this case + pass diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 2075ea7c52b..db1d006336f 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -10,6 +10,7 @@ Resize, CenterCrop, RandomResizedCrop, + RandomCrop, FiveCrop, TenCrop, BatchMultiCrop, diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index fd14ac0296b..88a118dbc9a 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -35,7 +35,8 @@ def __init__( antialias: Optional[bool] = None, ) -> None: super().__init__() - self.size = [size] if isinstance(size, int) else list(size) + + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.interpolation = interpolation self.max_size = max_size self.antialias = antialias @@ -80,7 +81,6 @@ def __init__( if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") - self.size = size self.scale = scale self.ratio = ratio self.interpolation = interpolation @@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> raise TypeError("Got inappropriate fill arg") +def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + +def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + class Pad(Transform): def __init__( self, @@ -233,18 +246,10 @@ def __init__( padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - - if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: - raise ValueError( - f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" - ) + _check_padding_arg(padding) _check_fill_arg(fill) - - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + _check_padding_mode_arg(padding_mode) self.padding = padding self.fill = fill @@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, center=self.center, ) + + +class RandomCrop(Transform): + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if padding is not None: + 
_check_padding_arg(padding) + + if (padding is not None) or pad_if_needed: + _check_padding_mode_arg(padding_mode) + _check_fill_arg(fill) + + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + _, height, width = get_image_dimensions(image) + output_height, output_width = self.size + + if height + 1 < output_height or width + 1 < output_width: + raise ValueError( + f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}" + ) + + if width == output_width and height == output_height: + return dict(top=0, left=0, height=height, width=width) + + top = torch.randint(0, height - output_height + 1, size=(1,)).item() + left = torch.randint(0, width - output_width + 1, size=(1,)).item() + return dict(top=top, left=left, height=output_height, width=output_width) + + def _forward(self, flat_inputs: List[Any]) -> List[Any]: + if self.padding is not None: + flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs] + + image = query_image(flat_inputs) + _, height, width = get_image_dimensions(image) + + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs] + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs] + + params = self._get_params(flat_inputs) + + return [F.crop(flat_input, **params) for flat_input in flat_inputs] + + def forward(self, *inputs: Any) -> Any: + from torch.utils._pytree import tree_flatten, tree_unflatten + + sample = inputs if len(inputs) > 1 else inputs[0] + + flat_inputs, spec = tree_flatten(sample) + out_flat_inputs = self._forward(flat_inputs) + return tree_unflatten(out_flat_inputs, spec) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 0517757a758..c41ef294975 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -2,6 +2,7 @@ import PIL.Image import torch +from torch.utils._pytree import tree_flatten from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively @@ -9,10 +10,20 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: + flat_sample, _ = tree_flatten(sample) + for i in flat_sample: + if type(i) == torch.Tensor or isinstance(i, (PIL.Image.Image, features.Image)): + return i + + raise TypeError("No image was found in the sample") + + +# vfdev-5: let's use tree_flatten instead of query_recursively and internal fn to make the code simplier +def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: def fn( id: Tuple[Any, ...], input: Any ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: - if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)): return id, input return None From bb2f4e107070a282003095976ade370879a53d9d Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 15 Jul 2022 15:43:49 +0200 Subject: [PATCH 03/13] 
[proto] Added GaussianBlur transform and tests (#6273) * Added GaussianBlur transform and tests * Fixing code format * Copied correctness test --- test/test_prototype_transforms.py | 55 +++++++++++++++ test/test_prototype_transforms_functional.py | 67 +++++++++++++++++++ torchvision/prototype/features/_feature.py | 3 + torchvision/prototype/features/_image.py | 6 ++ torchvision/prototype/transforms/__init__.py | 5 +- torchvision/prototype/transforms/_misc.py | 33 ++++++++- .../transforms/functional/__init__.py | 7 +- .../prototype/transforms/functional/_misc.py | 25 ++++++- 8 files changed, 197 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d561705fdfe..d6987f6b71b 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -644,3 +644,58 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): else: # vfdev-5: I do not know how to mock and test this case pass + + +class TestGaussianBlur: + def test_assertions(self): + with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"): + transforms.GaussianBlur([10, 12, 14]) + + with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"): + transforms.GaussianBlur(4) + + with pytest.raises(TypeError, match="sigma should be a single float or a list/tuple with length 2"): + transforms.GaussianBlur(3, sigma=[1, 2, 3]) + + with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"): + transforms.GaussianBlur(3, sigma=-1.0) + + with pytest.raises(ValueError, match="sigma values should be positive and of the form"): + transforms.GaussianBlur(3, sigma=[2.0, 1.0]) + + @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]]) + def test__get_params(self, sigma): + transform = transforms.GaussianBlur(3, sigma=sigma) + params = transform._get_params(None) + + if isinstance(sigma, float): + assert params["sigma"][0] == params["sigma"][1] == 10 + else: + assert sigma[0] <= params["sigma"][0] <= sigma[1] + assert sigma[0] <= params["sigma"][1] <= sigma[1] + + @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)]) + @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]]) + def test__transform(self, kernel_size, sigma, mocker): + transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma) + + if isinstance(kernel_size, (tuple, list)): + assert transform.kernel_size == kernel_size + else: + assert transform.kernel_size == (kernel_size, kernel_size) + + if isinstance(sigma, (tuple, list)): + assert transform.sigma == sigma + else: + assert transform.sigma == (sigma, sigma) + + fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") + inpt = features.Image(torch.rand(1, 3, 32, 32)) + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = transform(inpt) + torch.manual_seed(12) + params = transform._get_params(inpt) + + fn.assert_called_once_with(inpt, **params) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index c880e8db55b..e39eb4b6632 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,6 +1,7 @@ import functools import itertools import math +import os import numpy as np import pytest @@ -495,6 +496,7 @@ def center_crop_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn def 
center_crop_segmentation_mask(): for mask, output_size in itertools.product( make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))), @@ -503,6 +505,16 @@ def center_crop_segmentation_mask(): yield SampleInput(mask, output_size) +@register_kernel_info_from_sample_inputs_fn +def gaussian_blur_image_tensor(): + for image, kernel_size, sigma in itertools.product( + make_images(extra_dims=((4,),)), + [[3, 3]], + [None, [3.0, 3.0]], + ): + yield SampleInput(image, kernel_size=kernel_size, sigma=sigma) + + @pytest.mark.parametrize( "kernel", [ @@ -1555,3 +1567,58 @@ def _compute_expected_segmentation_mask(mask, output_size): expected = _compute_expected_segmentation_mask(mask, output_size) torch.testing.assert_close(expected, actual) + + +# Copied from test/test_functional_tensor.py +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("image_size", ("small", "large")) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) +@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) +def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, sigma): + fn = F.gaussian_blur_image_tensor + + # true_cv2_results = { + # # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) + # # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8) + # "3_3_0.8": ... + # # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5) + # "3_3_0.5": ... + # # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8) + # "3_5_0.8": ... + # # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5) + # "3_5_0.5": ... + # # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28)) + # # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7) + # "23_23_1.7": ... 
+ # } + p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") + true_cv2_results = torch.load(p) + + if image_size == "small": + tensor = ( + torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) + ) + else: + tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device) + + if dt == torch.float16 and device == "cpu": + # skip float16 on CPU case + return + + if dt is not None: + tensor = tensor.to(dtype=dt) + + _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize + _sigma = sigma[0] if sigma is not None else None + shape = tensor.shape + gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}" + if gt_key not in true_cv2_results: + return + + true_out = ( + torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) + ) + + out = fn(tensor, kernel_size=ksize, sigma=sigma) + torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 775f09f2f4b..6013672d7ef 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -189,3 +189,6 @@ def equalize(self) -> Any: def invert(self) -> Any: return self + + def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Any: + return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 447e67b33e9..0abda7b01d8 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -309,3 +309,9 @@ def invert(self) -> Image: output = _F.invert_image_tensor(self) return Image.new_like(self, output) + + def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: + from torchvision.prototype.transforms import functional as _F + + output = _F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma) + return Image.new_like(self, output) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index db1d006336f..f77b36d4643 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -22,7 +22,10 @@ RandomAffine, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace -from ._misc import Identity, Normalize, ToDtype, Lambda +from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip + +# TODO: add RandomPerspective, RandomInvert, RandomPosterize, RandomSolarize, +# RandomAdjustSharpness, RandomAutocontrast, ElasticTransform diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 54440ee05a5..b8e9101f2a0 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,8 +1,9 @@ import functools -from typing import Any, List, Type, Callable, Dict +from typing import Any, List, Type, Callable, Dict, Sequence, Union import torch from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms.transforms import _setup_size class Identity(Transform): @@ -46,6 +47,36 @@ def _transform(self, 
input: Any, params: Dict[str, Any]) -> Any: return input +class GaussianBlur(Transform): + def __init__( + self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0) + ) -> None: + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, float): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.") + + self.sigma = sigma + + def _get_params(self, sample: Any) -> Dict[str, Any]: + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + return dict(sigma=[sigma, sigma]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.gaussian_blur(inpt, **params) + + class ToDtype(Lambda): def __init__(self, dtype: torch.dtype, *types: Type) -> None: self.dtype = dtype diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index a8c17577a56..2d2618cf497 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -99,7 +99,12 @@ ten_crop_image_tensor, ten_crop_image_pil, ) -from ._misc import normalize_image_tensor, gaussian_blur_image_tensor +from ._misc import ( + normalize_image_tensor, + gaussian_blur, + gaussian_blur_image_tensor, + gaussian_blur_image_pil, +) from ._type_conversion import ( decode_image_with_pil, decode_video_with_av, diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 7b7139a5fd9..e51cac1745e 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,14 +1,28 @@ -from typing import Optional, List +from typing import Optional, List, Union import PIL.Image import torch +from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image +# shortcut type +DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] + + normalize_image_tensor = _FT.normalize +def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType: + if isinstance(inpt, features.Image): + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + elif type(inpt) == torch.Tensor: + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + else: + raise TypeError("Unsupported input type") + + def gaussian_blur_image_tensor( img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> torch.Tensor: @@ -42,3 +56,12 @@ def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optio t_img = pil_to_tensor(img) output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma) return to_pil_image(output, mode=img.mode) + + +def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType: + if isinstance(inpt, features._Feature): + return 
inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) + elif isinstance(inpt, PIL.Image.Image): + return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) + else: + return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) From 378f3c3142f26a4d6f0f97ea00fcb1ef0012d5d5 Mon Sep 17 00:00:00 2001 From: vfdev Date: Sat, 16 Jul 2022 15:06:17 +0200 Subject: [PATCH 04/13] [proto] Added random color transforms and tests (#6275) * Added random color transforms and tests * Disable smoke test for RandomSolarize, RandomAdjustSharpness --- test/test_prototype_transforms.py | 31 ++++++++++++ test/test_prototype_transforms_functional.py | 51 ++++++++++++++++++++ torchvision/prototype/transforms/__init__.py | 14 ++++-- torchvision/prototype/transforms/_color.py | 38 ++++++++++++++- 4 files changed, 129 insertions(+), 5 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d6987f6b71b..ca187aa5af5 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -83,6 +83,12 @@ class TestSmoke: transforms.RandomRotation(degrees=(-45, 45)), transforms.RandomAffine(degrees=(-45, 45)), transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True), + # TODO: Something wrong with input data setup. Let's fix that + # transforms.RandomEqualize(), + # transforms.RandomInvert(), + # transforms.RandomPosterize(bits=4), + # transforms.RandomSolarize(threshold=0.5), + # transforms.RandomAdjustSharpness(sharpness_factor=0.5), ) def test_common(self, transform, input): transform(input) @@ -699,3 +705,28 @@ def test__transform(self, kernel_size, sigma, mocker): params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params) + + +class TestRandomColorOp: + @pytest.mark.parametrize("p", [0.0, 1.0]) + @pytest.mark.parametrize( + "transform_cls, func_op_name, kwargs", + [ + (transforms.RandomEqualize, "equalize", {}), + (transforms.RandomInvert, "invert", {}), + (transforms.RandomAutocontrast, "autocontrast", {}), + (transforms.RandomPosterize, "posterize", {"bits": 4}), + (transforms.RandomSolarize, "solarize", {"threshold": 0.5}), + (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}), + ], + ) + def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): + transform = transform_cls(p=p, **kwargs) + + fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") + inpt = mocker.MagicMock(spec=features.Image) + _ = transform(inpt) + if p > 0.0: + fn.assert_called_once_with(inpt, **kwargs) + else: + fn.call_count == 0 diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e39eb4b6632..e2d5ff2d24d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -515,6 +515,57 @@ def gaussian_blur_image_tensor(): yield SampleInput(image, kernel_size=kernel_size, sigma=sigma) +@register_kernel_info_from_sample_inputs_fn +def equalize_image_tensor(): + for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + if image.dtype != torch.uint8: + continue + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def invert_image_tensor(): + for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def posterize_image_tensor(): + for image, bits in itertools.product( 
+ make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [1, 4, 8], + ): + if image.dtype != torch.uint8: + continue + yield SampleInput(image, bits=bits) + + +@register_kernel_info_from_sample_inputs_fn +def solarize_image_tensor(): + for image, threshold in itertools.product( + make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [0.1, 0.5, 127.0], + ): + if image.is_floating_point() and threshold > 1.0: + continue + yield SampleInput(image, threshold=threshold) + + +@register_kernel_info_from_sample_inputs_fn +def autocontrast_image_tensor(): + for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def adjust_sharpness_image_tensor(): + for image, sharpness_factor in itertools.product( + make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [0.1, 0.5], + ): + yield SampleInput(image, sharpness_factor=sharpness_factor) + + @pytest.mark.parametrize( "kernel", [ diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index f77b36d4643..42984847412 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -4,7 +4,16 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix -from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize +from ._color import ( + ColorJitter, + RandomPhotometricDistort, + RandomEqualize, + RandomInvert, + RandomPosterize, + RandomSolarize, + RandomAutocontrast, + RandomAdjustSharpness, +) from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, @@ -27,5 +36,4 @@ from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip -# TODO: add RandomPerspective, RandomInvert, RandomPosterize, RandomSolarize, -# RandomAdjustSharpness, RandomAutocontrast, ElasticTransform +# TODO: add RandomPerspective, ElasticTransform diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 85e22aaeb1a..7fd198161c9 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -151,8 +151,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomEqualize(_RandomApplyTransform): - def __init__(self, p: float = 0.5): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.equalize(inpt) + + +class RandomInvert(_RandomApplyTransform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.invert(inpt) + + +class RandomPosterize(_RandomApplyTransform): + def __init__(self, bits: int, p: float = 0.5) -> None: super().__init__(p=p) + self.bits = bits def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.equalize(inpt) + return F.posterize(inpt, bits=self.bits) + + +class RandomSolarize(_RandomApplyTransform): + def __init__(self, threshold: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.threshold = threshold + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.solarize(inpt, threshold=self.threshold) + + +class RandomAutocontrast(_RandomApplyTransform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.autocontrast(inpt) + + +class 
RandomAdjustSharpness(_RandomApplyTransform): + def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.sharpness_factor = sharpness_factor + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) From 3a9aca1698f9e3988bbb8a95467fa424bf9b65af Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 18 Jul 2022 18:44:39 +0200 Subject: [PATCH 05/13] Added RandomPerspective and tests (#6284) - replaced real image creation by mocks for other tests --- test/test_prototype_transforms.py | 108 +++++++++++++++--- test/test_prototype_transforms_functional.py | 4 +- torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/_geometry.py | 58 ++++++++++ .../transforms/functional/_geometry.py | 6 +- 5 files changed, 157 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index ca187aa5af5..b8cfabcccc9 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -350,7 +350,7 @@ def test__transform(self, padding, fill, padding_mode, mocker): transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - inpt = mocker.MagicMock(spec=torch.Tensor) + inpt = mocker.MagicMock(spec=features.Image) _ = transform(inpt) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) @@ -369,11 +369,12 @@ def test_assertions(self): @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__get_params(self, fill, side_range): + def test__get_params(self, fill, side_range, mocker): transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) - image = features.Image(torch.rand(1, 3, 32, 32)) - c, h, w = image.shape[-3:] + image = mocker.MagicMock(spec=features.Image) + c = image.num_channels = 3 + h, w = image.image_size = (24, 32) params = transform._get_params(image) @@ -387,19 +388,22 @@ def test__get_params(self, fill, side_range): @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) def test__transform(self, fill, side_range, mocker): - image = features.Image(torch.rand(1, 3, 32, 32)) + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) fn = mocker.patch("torchvision.prototype.transforms.functional.pad") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users torch.manual_seed(12) - _ = transform(image) + _ = transform(inpt) torch.manual_seed(12) torch.rand(1) # random apply changes random state - params = transform._get_params(image) + params = transform._get_params(inpt) - fn.assert_called_once_with(image, **params) + fn.assert_called_once_with(inpt, **params) class TestRandomRotation: @@ -449,7 +453,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): assert transform.degrees == [float(-degrees), float(degrees)] fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") - inpt = mocker.MagicMock(spec=torch.Tensor) + inpt = mocker.MagicMock(spec=features.Image) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users torch.manual_seed(12) @@ -504,9 +508,11 @@ 
def test_assertions(self): @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) - def test__get_params(self, degrees, translate, scale, shear): - image = features.Image(torch.rand(1, 3, 32, 32)) - h, w = image.shape[-2:] + def test__get_params(self, degrees, translate, scale, shear, mocker): + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + h, w = image.image_size transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) params = transform._get_params(image) @@ -564,7 +570,10 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker assert transform.degrees == [float(-degrees), float(degrees)] fn = mocker.patch("torchvision.prototype.transforms.functional.affine") - inpt = features.Image(torch.rand(1, 3, 32, 32)) + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users torch.manual_seed(12) @@ -592,9 +601,11 @@ def test_assertions(self): with pytest.raises(ValueError, match="Padding mode should be either"): transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") - def test__get_params(self): - image = features.Image(torch.rand(1, 3, 32, 32)) - h, w = image.shape[-2:] + def test__get_params(self, mocker): + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + h, w = image.image_size transform = transforms.RandomCrop([10, 10]) params = transform._get_params(image) @@ -614,7 +625,10 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode ) - inpt = features.Image(torch.rand(1, 3, 32, 32)) + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (32, 32) + expected = mocker.MagicMock(spec=features.Image) expected.num_channels = 3 if isinstance(padding, int): @@ -696,7 +710,10 @@ def test__transform(self, kernel_size, sigma, mocker): assert transform.sigma == (sigma, sigma) fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") - inpt = features.Image(torch.rand(1, 3, 32, 32)) + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users torch.manual_seed(12) @@ -730,3 +747,58 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): fn.assert_called_once_with(inpt, **kwargs) else: fn.call_count == 0 + + +class TestRandomPerspective: + def test_assertions(self): + with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"): + transforms.RandomPerspective(distortion_scale=-1.0) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomPerspective(0.5, fill="abc") + + def test__get_params(self, mocker): + dscale = 0.5 + transform = transforms.RandomPerspective(dscale) + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + params = transform._get_params(image) + + h, w = image.image_size + assert len(params["startpoints"]) == 4 + for x, y in params["startpoints"]: + assert x in 
(0, w - 1) + assert y in (0, h - 1) + + assert len(params["endpoints"]) == 4 + for (x, y), name in zip(params["endpoints"], ["tl", "tr", "br", "bl"]): + if "t" in name: + assert 0 <= y <= int(dscale * h // 2), (x, y, name) + if "b" in name: + assert h - int(dscale * h // 2) - 1 <= y <= h, (x, y, name) + if "l" in name: + assert 0 <= x <= int(dscale * w // 2), (x, y, name) + if "r" in name: + assert w - int(dscale * w // 2) - 1 <= x <= w, (x, y, name) + + @pytest.mark.parametrize("distortion_scale", [0.1, 0.7]) + def test__transform(self, distortion_scale, mocker): + interpolation = InterpolationMode.BILINEAR + fill = 12 + transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) + + fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + # vfdev-5, Feature Request: let's store params as Transform attribute + # This could be also helpful for users + torch.manual_seed(12) + _ = transform(inpt) + torch.manual_seed(12) + torch.rand(1) # random apply changes random state + params = transform._get_params(inpt) + + fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e2d5ff2d24d..873516869f8 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -599,9 +599,11 @@ def test_scriptable(kernel): and all( feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} ) - and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"} + and name + not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"} # We skip 'crop' due to missing 'height' and 'width' # We skip 'rotate' due to non implemented yet expand=True case for bboxes + # We skip 'perspective' as it requires different input args than perspective_image_tensor etc ], ) def test_functional_mid_level(func): diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 42984847412..c41171a05be 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -29,6 +29,7 @@ RandomZoomOut, RandomRotation, RandomAffine, + RandomPerspective, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 88a118dbc9a..3cf3858720e 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -292,6 +292,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: bottom = canvas_height - (top + orig_h) padding = [left, top, right, bottom] + # vfdev-5: Can we put that into pad_image_tensor ? 
fill = self.fill if not isinstance(fill, collections.abc.Sequence): fill = [fill] * orig_c @@ -493,3 +494,60 @@ def forward(self, *inputs: Any) -> Any: flat_inputs, spec = tree_flatten(sample) out_flat_inputs = self._forward(flat_inputs) return tree_unflatten(out_flat_inputs, spec) + + +class RandomPerspective(_RandomApplyTransform): + def __init__( + self, + distortion_scale: float, + fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + _check_fill_arg(fill) + if not (0 <= distortion_scale <= 1): + raise ValueError("Argument distortion_scale value should be between 0 and 1") + + self.distortion_scale = distortion_scale + self.interpolation = interpolation + self.fill = fill + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, height, width = get_image_dimensions(image) + + distortion_scale = self.distortion_scale + + half_height = height // 2 + half_width = width // 2 + topleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + ] + topright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + ] + botright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + ] + botleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + return dict(startpoints=startpoints, endpoints=endpoints) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.perspective( + inpt, + **params, + fill=self.fill, + interpolation=self.interpolation, + ) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 8d3ed675047..87419ba8640 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -11,6 +11,7 @@ _get_inverse_affine_matrix, InterpolationMode, _compute_output_size, + _get_perspective_coeffs, ) from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil @@ -765,10 +766,13 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl def perspective( inpt: DType, - perspective_coeffs: List[float], + startpoints: List[List[int]], + endpoints: List[List[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> DType: + perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) + if isinstance(inpt, features._Feature): return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill) elif isinstance(inpt, PIL.Image.Image): From 794443d241379afdcfe41ac75e4b70dfff5cfab8 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 18 Jul 2022 22:07:14 +0200 Subject: [PATCH 06/13] 
Added more functional tests (#6285) --- test/test_prototype_transforms_functional.py | 72 +++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 873516869f8..61d1adfab18 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -200,6 +200,30 @@ def horizontal_flip_bounding_box(): yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_segmentation_mask(): + for mask in make_segmentation_masks(): + yield SampleInput(mask) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_image_tensor(): + for image in make_images(): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_bounding_box(): + for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): + yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_segmentation_mask(): + for mask in make_segmentation_masks(): + yield SampleInput(mask) + + @register_kernel_info_from_sample_inputs_fn def resize_image_tensor(): for image, interpolation, max_size, antialias in itertools.product( @@ -404,9 +428,17 @@ def crop_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn -def vertical_flip_segmentation_mask(): - for mask in make_segmentation_masks(): - yield SampleInput(mask) +def resized_crop_image_tensor(): + for mask, top, left, height, width, size, antialias in itertools.product( + make_images(), + [-8, 9], + [-8, 9], + [12], + [12], + [(16, 18)], + [True, False], + ): + yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn @@ -457,6 +489,19 @@ def pad_bounding_box(): yield SampleInput(bounding_box, padding=padding, format=bounding_box.format) +@register_kernel_info_from_sample_inputs_fn +def perspective_image_tensor(): + for image, perspective_coeffs, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [ + [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], + [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], + ], + [None, [128], [12.0]], # fill + ): + yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill) + + @register_kernel_info_from_sample_inputs_fn def perspective_bounding_box(): for bounding_box, perspective_coeffs in itertools.product( @@ -488,6 +533,15 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def center_crop_image_tensor(): + for mask, output_size in itertools.product( + make_images(sizes=((16, 16), (7, 33), (31, 9))), + [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size + ): + yield SampleInput(mask, output_size) + + @register_kernel_info_from_sample_inputs_fn def center_crop_bounding_box(): for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]): @@ -1181,6 +1235,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_): torch.testing.assert_close(output_mask, expected_mask) +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): + mask = torch.zeros((3, 3, 3), 
dtype=torch.long, device=device) + mask[:, :, 0] = 1 + + out_mask = F.horizontal_flip_segmentation_mask(mask) + + expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + expected_mask[:, :, -1] = 1 + torch.testing.assert_close(out_mask, expected_mask) + + @pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) From 94c7ddeb569f04509682a5f5a137620abc5f1657 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 22 Jul 2022 19:38:10 +0200 Subject: [PATCH 07/13] [proto] Added elastic transform and tests (#6295) * WIP [proto] Added functional elastic transform with tests * Added more functional tests * WIP on elastic op * Added elastic transform and tests * Added tests * Added tests for ElasticTransform --- test/test_functional_tensor.py | 18 ++- test/test_prototype_transforms.py | 75 +++++++++++++ test/test_prototype_transforms_functional.py | 106 ++++++++++++++++-- .../prototype/features/_bounding_box.py | 11 ++ torchvision/prototype/features/_feature.py | 8 ++ torchvision/prototype/features/_image.py | 13 +++ .../prototype/features/_segmentation_mask.py | 12 ++ torchvision/prototype/transforms/__init__.py | 3 +- torchvision/prototype/transforms/_geometry.py | 69 ++++++++++++ .../transforms/functional/__init__.py | 6 + .../transforms/functional/_geometry.py | 92 ++++++++++++++- torchvision/transforms/functional.py | 13 ++- torchvision/transforms/functional_tensor.py | 10 +- 13 files changed, 414 insertions(+), 22 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 969aedf6d2d..bec868c88fd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1352,16 +1352,24 @@ def test_ten_crop(device): assert_equal(transformed_batch, s_transformed_batch) +def test_elastic_transform_asserts(): + with pytest.raises(TypeError, match="Argument displacement should be a Tensor"): + _ = F.elastic_transform("abc", displacement=None) + + with pytest.raises(TypeError, match="img should be PIL Image or Tensor"): + _ = F.elastic_transform("abc", displacement=torch.rand(1)) + + img_tensor = torch.rand(1, 3, 32, 24) + with pytest.raises(ValueError, match="Argument displacement shape should"): + _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2)) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( "fill", - [ - None, - [255, 255, 255], - (2.0,), - ], + [None, [255, 255, 255], (2.0,)], ) def test_elastic_transform_consistency(device, interpolation, dt, fill): script_elastic_transform = torch.jit.script(F.elastic_transform) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b8cfabcccc9..51eba38c7a6 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.pad") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): inpt = mocker.MagicMock(spec=features.Image) # 
vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -795,6 +800,7 @@ def test__transform(self, distortion_scale, mocker): inpt.image_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -802,3 +808,72 @@ def test__transform(self, distortion_scale, mocker): params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) + + +class TestElasticTransform: + def test_assertions(self): + + with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): + transforms.ElasticTransform({}) + + with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): + transforms.ElasticTransform([1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="alpha should be a sequence of floats"): + transforms.ElasticTransform([1, 2]) + + with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"): + transforms.ElasticTransform(1.0, {}) + + with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): + transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="sigma should be a sequence of floats"): + transforms.ElasticTransform(1.0, [1, 2]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.ElasticTransform(1.0, 2.0, fill="abc") + + def test__get_params(self, mocker): + alpha = 2.0 + sigma = 3.0 + transform = transforms.ElasticTransform(alpha, sigma) + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + params = transform._get_params(image) + + h, w = image.image_size + displacement = params["displacement"] + assert displacement.shape == (1, h, w, 2) + assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() + assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() + + @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]]) + @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]]) + def test__transform(self, alpha, sigma, mocker): + interpolation = InterpolationMode.BILINEAR + fill = 12 + transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, 
interpolation=interpolation) + + if isinstance(alpha, float): + assert transform.alpha == [alpha, alpha] + else: + assert transform.alpha == alpha + + if isinstance(sigma, float): + assert transform.sigma == [sigma, sigma] + else: + assert transform.sigma == sigma + + fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + + # Let's mock transform._get_params to control the output: + transform._get_params = mocker.MagicMock() + _ = transform(inpt) + params = transform._get_params(inpt) + fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 61d1adfab18..fb5f10459fe 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -59,7 +59,7 @@ def make_images( yield make_image(size, color_space=color_space, dtype=dtype) for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): - yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype) + yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype) def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): @@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype def make_segmentation_masks( - image_sizes=((16, 16), (7, 33), (31, 9)), + sizes=((16, 16), (7, 33), (31, 9)), dtypes=(torch.long,), extra_dims=((), (4,), (2, 3)), ): - for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): - yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) + for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_) class SampleInput: @@ -533,6 +533,40 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def elastic_image_tensor(): + for image, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [None, [128], [12.0]], # fill + ): + h, w = image.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput(image, displacement=displacement, fill=fill) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_bounding_box(): + for bounding_box in make_bounding_boxes(): + h, w = bounding_box.image_size + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + bounding_box, + format=bounding_box.format, + displacement=displacement, + ) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_segmentation_mask(): + for mask in make_segmentation_masks(extra_dims=((), (4,))): + h, w = mask.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + mask, + displacement=displacement, + ) + + @register_kernel_info_from_sample_inputs_fn def center_crop_image_tensor(): for mask, output_size in itertools.product( @@ -553,7 +587,7 @@ def center_crop_bounding_box(): @register_kernel_info_from_sample_inputs_fn def center_crop_segmentation_mask(): for mask, output_size in itertools.product( - make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))), + make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size ): yield SampleInput(mask, output_size) @@ -654,10 +688,20 @@ def 
test_scriptable(kernel): feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} ) and name - not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"} + not in { + "to_image_tensor", + "InterpolationMode", + "decode_video_with_av", + "crop", + "rotate", + "perspective", + "elastic_transform", + "elastic", + } # We skip 'crop' due to missing 'height' and 'width' # We skip 'rotate' due to non implemented yet expand=True case for bboxes # We skip 'perspective' as it requires different input args than perspective_image_tensor etc + # Skip 'elastic', TODO: inspect why test is failing ], ) def test_functional_mid_level(func): @@ -670,7 +714,9 @@ def test_functional_mid_level(func): if key in kwargs: del kwargs[key] output = func(*sample_input.args, **kwargs) - torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + torch.testing.assert_close( + output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}" + ) break @@ -1739,5 +1785,49 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) ) - out = fn(tensor, kernel_size=ksize, sigma=sigma) + image = features.Image(tensor) + + out = fn(image, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)] +) +def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): + in_box = [10, 15, 25, 35] + for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))): + c, h, w = sample.shape[-3:] + # Setup a dummy image with 4 points + sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample = sample.to(device) + + if fn == F.elastic_image_tensor: + sample = features.Image(sample) + kwargs = {"interpolation": F.InterpolationMode.NEAREST} + else: + sample = features.SegmentationMask(sample) + kwargs = {} + + # Create a displacement grid using sin + n, m = 5.0, 0.1 + d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h) + d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w) + + d1 = d1[:, None].expand((h, w)) + d2 = d2[None, :].expand((h, w)) + + displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1) + displacement = displacement.reshape(1, h, w, 2) + + output = fn(sample, displacement=displacement, **kwargs) + + # Check places where transformed points should be + torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]]) + torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1]) + torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]]) + torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1]) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index eb9d1f6ac3a..f1957df5f1d 100644 --- 
a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -207,3 +207,14 @@ def perspective( output = _F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> BoundingBox: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_bounding_box(self, self.format, displacement) + return BoundingBox.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 6013672d7ef..b2dedea86d3 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -157,6 +157,14 @@ def perspective( ) -> Any: return self + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Any: + return self + def adjust_brightness(self, brightness_factor: float) -> Any: return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 0abda7b01d8..18cf7c1964d 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -244,6 +244,19 @@ def perspective( output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) return Image.new_like(self, output) + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Image: + from torchvision.prototype.transforms.functional import _geometry as _F + + fill = _F._convert_fill_arg(fill) + + output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) + return Image.new_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Image: from torchvision.prototype.transforms import functional as _F diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index f894f33d1b2..5f7ea80430b 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union, Sequence +import torch from torchvision.transforms import InterpolationMode from ._feature import _Feature @@ -119,3 +120,14 @@ def perspective( output = _F.perspective_segmentation_mask(self, perspective_coeffs) return SegmentationMask.new_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> SegmentationMask: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_segmentation_mask(self, displacement) + return SegmentationMask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c41171a05be..483f811ef7d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -30,11 +30,10 @@ 
RandomRotation, RandomAffine, RandomPerspective, + ElasticTransform, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip - -# TODO: add RandomPerspective, ElasticTransform diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 3cf3858720e..8a1b94060c4 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -551,3 +551,72 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, interpolation=self.interpolation, ) + + +def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: + if not isinstance(arg, (float, Sequence)): + raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) != req_size: + raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, float): + raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") + + if isinstance(arg, float): + arg = [float(arg), float(arg)] + if isinstance(arg, (list, tuple)) and len(arg) == 1: + arg = [arg[0], arg[0]] + return arg + + +class ElasticTransform(Transform): + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.alpha = _setup_float_or_seq(alpha, "alpha", 2) + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + _check_fill_arg(fill) + + self.interpolation = interpolation + self.fill = fill + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, *size = get_image_dimensions(image) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.elastic( + inpt, + **params, + fill=self.fill, + interpolation=self.interpolation, + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2d2618cf497..638049f96fb 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -89,6 +89,12 @@ perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + elastic, + elastic_transform, + elastic_bounding_box, + 
elastic_image_tensor, + elastic_image_pil, + elastic_segmentation_mask, vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 87419ba8640..e48701bb436 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -12,6 +12,8 @@ InterpolationMode, _compute_output_size, _get_perspective_coeffs, + pil_to_tensor, + to_pil_image, ) from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil @@ -760,8 +762,10 @@ def perspective_bounding_box( ).view(original_shape) -def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: - return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: + return perspective_image_tensor( + mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST + ) def perspective( @@ -783,6 +787,90 @@ def perspective( return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) +def elastic_image_tensor( + img: torch.Tensor, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> torch.Tensor: + return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) + + +def elastic_image_pil( + img: PIL.Image.Image, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> PIL.Image.Image: + t_img = pil_to_tensor(img) + fill = _convert_fill_arg(fill) + + output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) + return to_pil_image(output, mode=img.mode) + + +def elastic_bounding_box( + bounding_box: torch.Tensor, + format: features.BoundingBoxFormat, + displacement: torch.Tensor, +) -> torch.Tensor: + displacement = displacement.to(bounding_box.device) + + original_shape = bounding_box.shape + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it + # Or add image_size arg and check displacement shape + image_size = displacement.shape[-3], displacement.shape[-2] + + id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device) + # We construct an approximation of inverse grid as inv_grid = id_grid - displacement + # This is not an exact inverse of the grid + inv_grid = id_grid - displacement + + # Get points from bboxes + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) + index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long) + index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long) + # Transform points: + t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5 + + transformed_points = transformed_points.view(-1, 4, 2) + out_bbox_mins, _ = torch.min(transformed_points, dim=1) + out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bboxes = 
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + + return convert_bounding_box_format( + out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(original_shape) + + +def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: + return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST) + + +def elastic( + inpt: DType, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> DType: + if isinstance(inpt, features._Feature): + return inpt.elastic(displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, PIL.Image.Image): + return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) + else: + fill = _convert_fill_arg(fill) + + return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) + + +elastic_transform = elastic + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index eea53a228a9..442e8d4288d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1555,7 +1555,7 @@ def elastic_transform( If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". - displacement (Tensor): The displacement field. + displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2]. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. @@ -1577,7 +1577,7 @@ def elastic_transform( interpolation = _interpolation_modes_from_int(interpolation) if not isinstance(displacement, torch.Tensor): - raise TypeError("displacement should be a Tensor") + raise TypeError("Argument displacement should be a Tensor") t_img = img if not isinstance(img, torch.Tensor): @@ -1585,6 +1585,15 @@ def elastic_transform( raise TypeError(f"img should be PIL Image or Tensor. 
Got {type(img)}") t_img = pil_to_tensor(img) + shape = t_img.shape + shape = (1,) + shape[-2:] + (2,) + if shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}") + + # TODO: if image shape is [N1, N2, ..., C, H, W] and + # displacement is [1, H, W, 2] we need to reshape input image + # such grid_sampler takes internal code for 4D input + output = F_t.elastic_transform( t_img, displacement, diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2b0872acf8a..c35edfb74b0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -932,6 +932,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return img +def _create_identity_grid(size: List[int]) -> Tensor: + hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] + grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") + return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + + def elastic_transform( img: Tensor, displacement: Tensor, @@ -945,8 +951,6 @@ def elastic_transform( size = list(img.shape[-2:]) displacement = displacement.to(img.device) - hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] - grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") - identity_grid = torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + identity_grid = _create_identity_grid(size) grid = identity_grid.to(img.device) + displacement return _apply_grid_transform(img, grid, interpolation, fill) From b7ed6832b43e8db24a8b10d51ff754fefa0b7894 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 22 Jul 2022 19:49:06 +0200 Subject: [PATCH 08/13] Try to format code as in https://github.com/pytorch/vision/pull/5106 --- .../classification/train_quantization.py | 2 +- references/detection/group_by_aspect_ratio.py | 2 +- references/detection/train.py | 4 +- references/detection/transforms.py | 5 +- references/optical_flow/train.py | 4 +- references/optical_flow/utils.py | 7 +- references/segmentation/utils.py | 2 +- test/builtin_dataset_mocks.py | 4 +- test/conftest.py | 2 +- test/datasets_utils.py | 2 +- test/test_datasets_download.py | 6 +- test/test_datasets_samplers.py | 8 +- test/test_datasets_video_utils.py | 4 +- test/test_extended_models.py | 2 +- test/test_functional_tensor.py | 8 +- test/test_image.py | 20 +-- test/test_models.py | 2 +- .../test_models_detection_negative_samples.py | 4 +- test/test_onnx.py | 9 +- test/test_ops.py | 4 +- test/test_prototype_builtin_datasets.py | 8 +- test/test_prototype_datasets_utils.py | 4 +- test/test_prototype_models.py | 2 +- test/test_prototype_transforms.py | 10 +- test/test_transforms.py | 2 +- test/test_transforms_tensor.py | 15 +- test/test_utils.py | 6 +- torchvision/__init__.py | 7 +- torchvision/_utils.py | 2 +- torchvision/datasets/__init__.py | 14 +- torchvision/datasets/_optical_flow.py | 4 +- torchvision/datasets/caltech.py | 2 +- torchvision/datasets/celeba.py | 4 +- torchvision/datasets/cityscapes.py | 4 +- torchvision/datasets/clevr.py | 2 +- torchvision/datasets/coco.py | 2 +- torchvision/datasets/country211.py | 2 +- torchvision/datasets/dtd.py | 4 +- torchvision/datasets/fer2013.py | 2 +- torchvision/datasets/flowers102.py | 2 +- torchvision/datasets/folder.py | 3 +- torchvision/datasets/food101.py | 4 +- torchvision/datasets/hmdb51.py | 2 +- torchvision/datasets/imagenet.py | 2 +- 
torchvision/datasets/inaturalist.py | 2 +- torchvision/datasets/kinetics.py | 2 +- torchvision/datasets/lsun.py | 2 +- torchvision/datasets/mnist.py | 2 +- torchvision/datasets/omniglot.py | 2 +- torchvision/datasets/oxford_iiit_pet.py | 3 +- torchvision/datasets/pcam.py | 2 +- torchvision/datasets/places365.py | 2 +- torchvision/datasets/rendered_sst2.py | 4 +- torchvision/datasets/samplers/__init__.py | 2 +- torchvision/datasets/samplers/clip_sampler.py | 2 +- torchvision/datasets/sbd.py | 2 +- torchvision/datasets/sbu.py | 2 +- torchvision/datasets/semeion.py | 2 +- torchvision/datasets/stanford_cars.py | 2 +- torchvision/datasets/stl10.py | 2 +- torchvision/datasets/sun397.py | 2 +- torchvision/datasets/svhn.py | 2 +- torchvision/datasets/ucf101.py | 2 +- torchvision/datasets/utils.py | 7 +- torchvision/datasets/video_utils.py | 9 +- torchvision/datasets/voc.py | 2 +- torchvision/datasets/widerface.py | 9 +- torchvision/extension.py | 1 - torchvision/io/__init__.py | 12 +- torchvision/io/_video_opt.py | 2 +- torchvision/io/video_reader.py | 5 +- torchvision/models/__init__.py | 6 +- torchvision/models/_api.py | 2 +- torchvision/models/_utils.py | 2 +- torchvision/models/alexnet.py | 4 +- torchvision/models/convnext.py | 4 +- torchvision/models/densenet.py | 4 +- torchvision/models/detection/_utils.py | 4 +- .../models/detection/backbone_utils.py | 4 +- torchvision/models/detection/faster_rcnn.py | 14 +- torchvision/models/detection/fcos.py | 12 +- .../models/detection/generalized_rcnn.py | 2 +- torchvision/models/detection/keypoint_rcnn.py | 6 +- torchvision/models/detection/mask_rcnn.py | 8 +- torchvision/models/detection/retinanet.py | 14 +- torchvision/models/detection/roi_heads.py | 2 +- torchvision/models/detection/rpn.py | 20 ++- torchvision/models/detection/ssd.py | 6 +- torchvision/models/detection/ssdlite.py | 6 +- torchvision/models/detection/transform.py | 2 +- torchvision/models/efficientnet.py | 6 +- torchvision/models/feature_extraction.py | 5 +- torchvision/models/googlenet.py | 6 +- torchvision/models/inception.py | 6 +- torchvision/models/mnasnet.py | 4 +- torchvision/models/mobilenetv2.py | 9 +- torchvision/models/mobilenetv3.py | 4 +- torchvision/models/quantization/googlenet.py | 6 +- torchvision/models/quantization/inception.py | 6 +- .../models/quantization/mobilenetv2.py | 11 +- .../models/quantization/mobilenetv3.py | 10 +- torchvision/models/quantization/resnet.py | 8 +- .../models/quantization/shufflenetv2.py | 4 +- torchvision/models/regnet.py | 4 +- torchvision/models/resnet.py | 6 +- torchvision/models/segmentation/_utils.py | 2 +- torchvision/models/segmentation/deeplabv3.py | 8 +- torchvision/models/segmentation/fcn.py | 6 +- torchvision/models/segmentation/lraspp.py | 6 +- torchvision/models/shufflenetv2.py | 6 +- torchvision/models/squeezenet.py | 4 +- torchvision/models/swin_transformer.py | 6 +- torchvision/models/vgg.py | 6 +- torchvision/models/video/mvit.py | 4 +- torchvision/models/video/resnet.py | 6 +- torchvision/models/vision_transformer.py | 6 +- torchvision/ops/__init__.py | 14 +- torchvision/ops/boxes.py | 4 +- torchvision/ops/ciou_loss.py | 2 +- torchvision/ops/drop_block.py | 4 +- torchvision/ops/feature_pyramid_network.py | 2 +- torchvision/ops/giou_loss.py | 2 +- torchvision/ops/misc.py | 2 +- torchvision/ops/poolers.py | 2 +- torchvision/ops/ps_roi_align.py | 2 +- torchvision/ops/ps_roi_pool.py | 2 +- torchvision/ops/roi_align.py | 2 +- torchvision/ops/roi_pool.py | 2 +- torchvision/prototype/__init__.py | 6 +- 
torchvision/prototype/datasets/_api.py | 2 +- .../prototype/datasets/_builtin/__init__.py | 2 +- .../prototype/datasets/_builtin/caltech.py | 15 +- .../prototype/datasets/_builtin/celeba.py | 24 +--- .../prototype/datasets/_builtin/cifar.py | 12 +- .../prototype/datasets/_builtin/clevr.py | 12 +- .../prototype/datasets/_builtin/coco.py | 29 ++-- .../prototype/datasets/_builtin/country211.py | 4 +- .../prototype/datasets/_builtin/cub200.py | 22 +-- .../prototype/datasets/_builtin/dtd.py | 18 +-- .../prototype/datasets/_builtin/fer2013.py | 15 +- .../prototype/datasets/_builtin/food101.py | 21 +-- .../prototype/datasets/_builtin/gtsrb.py | 12 +- .../prototype/datasets/_builtin/imagenet.py | 24 ++-- .../prototype/datasets/_builtin/mnist.py | 6 +- .../datasets/_builtin/oxford_iiit_pet.py | 18 +-- .../prototype/datasets/_builtin/pcam.py | 13 +- .../prototype/datasets/_builtin/sbd.py | 19 +-- .../prototype/datasets/_builtin/semeion.py | 12 +- .../datasets/_builtin/stanford_cars.py | 4 +- .../prototype/datasets/_builtin/svhn.py | 22 +-- .../prototype/datasets/_builtin/usps.py | 4 +- .../prototype/datasets/_builtin/voc.py | 21 +-- torchvision/prototype/datasets/_folder.py | 6 +- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 2 +- .../prototype/datasets/utils/_internal.py | 16 +-- .../prototype/datasets/utils/_resource.py | 18 +-- .../prototype/features/_bounding_box.py | 2 +- torchvision/prototype/features/_encoded.py | 2 +- torchvision/prototype/features/_feature.py | 2 +- torchvision/prototype/features/_image.py | 7 +- torchvision/prototype/features/_label.py | 2 +- .../prototype/features/_segmentation_mask.py | 2 +- .../models/depth/stereo/raft_stereo.py | 6 +- torchvision/prototype/transforms/__init__.py | 34 ++--- torchvision/prototype/transforms/_augment.py | 4 +- .../prototype/transforms/_auto_augment.py | 6 +- torchvision/prototype/transforms/_color.py | 6 +- .../prototype/transforms/_container.py | 2 +- torchvision/prototype/transforms/_geometry.py | 10 +- torchvision/prototype/transforms/_meta.py | 4 +- torchvision/prototype/transforms/_misc.py | 4 +- .../prototype/transforms/_type_conversion.py | 2 +- torchvision/prototype/transforms/_utils.py | 4 +- .../transforms/functional/__init__.py | 133 +++++++++--------- .../prototype/transforms/functional/_color.py | 2 +- .../transforms/functional/_geometry.py | 12 +- .../prototype/transforms/functional/_meta.py | 4 +- .../prototype/transforms/functional/_misc.py | 2 +- .../transforms/functional/_type_conversion.py | 2 +- torchvision/prototype/utils/_internal.py | 13 +- torchvision/transforms/_presets.py | 2 +- torchvision/transforms/_transforms_video.py | 5 +- torchvision/transforms/autoaugment.py | 2 +- torchvision/transforms/functional.py | 5 +- torchvision/transforms/functional_tensor.py | 6 +- torchvision/transforms/transforms.py | 4 +- torchvision/utils.py | 6 +- 188 files changed, 553 insertions(+), 740 deletions(-) diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index c0e5af1dcfc..a66a47f8674 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -9,7 +9,7 @@ import torchvision import utils from torch import nn -from train import train_one_epoch, evaluate, load_data +from train import evaluate, load_data, train_one_epoch def main(args): diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py index 
1323849a6a4..5312cc036d6 100644 --- a/references/detection/group_by_aspect_ratio.py +++ b/references/detection/group_by_aspect_ratio.py @@ -2,7 +2,7 @@ import copy import math from collections import defaultdict -from itertools import repeat, chain +from itertools import chain, repeat import numpy as np import torch diff --git a/references/detection/train.py b/references/detection/train.py index f56ac66881c..178f7460417 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -29,8 +29,8 @@ import torchvision.models.detection.mask_rcnn import utils from coco_utils import get_coco, get_coco_kp -from engine import train_one_epoch, evaluate -from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups +from engine import evaluate, train_one_epoch +from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode from transforms import SimpleCopyPaste diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 35ae34bd56a..7da854505f2 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,11 +1,10 @@ -from typing import List, Tuple, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torchvision from torch import nn, Tensor from torchvision import ops -from torchvision.transforms import functional as F -from torchvision.transforms import transforms as T, InterpolationMode +from torchvision.transforms import functional as F, InterpolationMode, transforms as T def _flip_coco_person_keypoints(kps, width): diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 7c4c45ab275..0327d92bdf9 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -6,8 +6,8 @@ import torch import torchvision.models.optical_flow import utils -from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval -from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K +from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain +from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel def get_train_dataset(stage, dataset_root): diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py index 065a2be8bfc..8b07e9de35c 100644 --- a/references/optical_flow/utils.py +++ b/references/optical_flow/utils.py @@ -1,8 +1,7 @@ import datetime import os import time -from collections import defaultdict -from collections import deque +from collections import defaultdict, deque import torch import torch.distributed as dist @@ -158,7 +157,7 @@ def log_every(self, iterable, print_freq=5, header=None): def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None): epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt() - flow_norm = (flow_gt ** 2).sum(dim=1).sqrt() + flow_norm = (flow_gt**2).sum(dim=1).sqrt() if valid_flow_mask is not None: epe = epe[valid_flow_mask] @@ -183,7 +182,7 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400) raise ValueError(f"Gamma should be < 1, got {gamma}.") # exlude invalid pixels and extremely large diplacements - flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt() + flow_norm = torch.sum(flow_gt**2, dim=1).sqrt() valid_flow_mask = valid_flow_mask & (flow_norm < max_flow) valid_flow_mask = valid_flow_mask[:, None, :, :] diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index 
dfd12726b53..4ea24db83ed 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -75,7 +75,7 @@ def update(self, a, b): with torch.inference_mode(): k = (a >= 0) & (a < n) inds = n * a[k].to(torch.int64) + b[k] - self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) def reset(self): self.mat.zero_() diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0fe0cbd6dd7..8c5484a2823 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -14,12 +14,12 @@ import unittest.mock import warnings import xml.etree.ElementTree as ET -from collections import defaultdict, Counter +from collections import Counter, defaultdict import numpy as np import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid +from datasets_utils import combinations_grid, create_image_file, create_image_folder, make_tar, make_zip from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype import datasets diff --git a/test/conftest.py b/test/conftest.py index a8b9054a4e5..1a9b2db7f5c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG +from common_utils import CIRCLECI_GPU_NO_CUDA_MSG, CUDA_NOT_AVAILABLE_MSG, IN_CIRCLE_CI, IN_FBCODE, IN_RE_WORKER def pytest_configure(config): diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 88eb4e17823..2043caae0a2 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -22,7 +22,7 @@ import torch import torchvision.datasets import torchvision.io -from common_utils import get_tmp_dir, disable_console_output +from common_utils import disable_console_output, get_tmp_dir __all__ = [ diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 4d2e475e1df..5fa7c6bca44 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -9,15 +9,15 @@ from os import path from urllib.error import HTTPError, URLError from urllib.parse import urlparse -from urllib.request import urlopen, Request +from urllib.request import Request, urlopen import pytest from torchvision import datasets from torchvision.datasets.utils import ( - download_url, + _get_redirect_url, check_integrity, download_file_from_google_drive, - _get_redirect_url, + download_url, USER_AGENT, ) diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 7174d6321f7..9e3826b2c13 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,12 +1,8 @@ import pytest import torch -from common_utils import get_list_of_videos, assert_equal +from common_utils import assert_equal, get_list_of_videos from torchvision import io -from torchvision.datasets.samplers import ( - DistributedSampler, - RandomClipSampler, - UniformClipSampler, -) +from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.datasets.video_utils import VideoClips diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index cfdbd6f6d02..adaa4f5446c 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,8 +1,8 @@ import pytest import torch -from common_utils import 
get_list_of_videos, assert_equal +from common_utils import assert_equal, get_list_of_videos from torchvision import io -from torchvision.datasets.video_utils import VideoClips, unfold +from torchvision.datasets.video_utils import unfold, VideoClips class TestVideo: diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 7961d173e3f..677d19d18f7 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -5,7 +5,7 @@ import test_models as TM import torch from torchvision import models -from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._api import Weights, WeightsEnum from torchvision.models._utils import handle_legacy_interface diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index bec868c88fd..1914bc571fb 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -14,14 +14,14 @@ import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional_tensor as F_t from common_utils import ( - cpu_and_gpu, - needs_cuda, + _assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, _test_fn_on_batch, assert_equal, + cpu_and_gpu, + needs_cuda, ) from torchvision.transforms import InterpolationMode diff --git a/test/test_image.py b/test/test_image.py index e4358f6f1e1..89374ebc8c5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,21 +8,21 @@ import pytest import torch import torchvision.transforms.functional as F -from common_utils import needs_cuda, assert_equal -from PIL import Image, __version__ as PILLOW_VERSION +from common_utils import assert_equal, needs_cuda +from PIL import __version__ as PILLOW_VERSION, Image from torchvision.io.image import ( - decode_png, + _read_png_16, + decode_image, decode_jpeg, + decode_png, encode_jpeg, - write_jpeg, - decode_image, - read_file, encode_png, - write_png, - write_file, ImageReadMode, + read_file, read_image, - _read_png_16, + write_file, + write_jpeg, + write_png, ) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") @@ -168,7 +168,7 @@ def test_decode_png(img_path, pil_mode, mode): img_lpng = _read_png_16(img_path, mode=mode) assert img_lpng.dtype == torch.int32 # PIL converts 16 bits pngs in uint8 - img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8) + img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8) else: data = read_file(img_path) img_lpng = decode_image(data, mode=mode) diff --git a/test/test_models.py b/test/test_models.py index 866fafae5f6..05bab11e479 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -14,7 +14,7 @@ import torch.fx import torch.nn as nn from _utils_internal import get_relative_path -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda +from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from torchvision import models ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index c4efbd96cf3..13db78d53fc 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -4,7 +4,7 @@ from common_utils import assert_equal from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from 
torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead from torchvision.ops import MultiScaleRoIAlign @@ -60,7 +60,7 @@ def test_assign_targets_to_proposals(self): resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(4 * resolution ** 2, representation_size) + box_head = TwoMLPHead(4 * resolution**2, representation_size) representation_size = 1024 box_predictor = FastRCNNPredictor(representation_size, 2) diff --git a/test/test_onnx.py b/test/test_onnx.py index ba0880a621d..d5dae64b4d0 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -4,13 +4,12 @@ import pytest import torch -from common_utils import set_rng_seed, assert_equal -from torchvision import models -from torchvision import ops +from common_utils import assert_equal, set_rng_seed +from torchvision import models, ops from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.ops._register_onnx_ops import _onnx_opset_version @@ -265,7 +264,7 @@ def _init_test_roi_heads_faster_rcnn(self): resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) + box_head = TwoMLPHead(out_channels * resolution**2, representation_size) representation_size = 1024 box_predictor = FastRCNNPredictor(representation_size, num_classes) diff --git a/test/test_ops.py b/test/test_ops.py index 96cfb630e8d..bc4f9d19464 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -79,7 +79,7 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar rois_dtype = self.dtype if rois_dtype is None else rois_dtype pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS opeartions. 
- n_channels = 2 * (pool_size ** 2) + n_channels = 2 * (pool_size**2) x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) @@ -115,7 +115,7 @@ def test_is_leaf_node(self, device): def test_backward(self, seed, device, contiguous): torch.random.manual_seed(seed) pool_size = 2 - x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 5a8c9e7eff8..6ddba1806c6 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -5,14 +5,14 @@ import pytest import torch -from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS -from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair +from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks +from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.utils.data import DataLoader from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes -from torchdata.datapipes.iter import Shuffler, ShardingFilter +from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchvision._utils import sequence_to_str -from torchvision.prototype import transforms, datasets +from torchvision.prototype import datasets, transforms from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.features import Image, Label diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index 8790b1638f9..2098ac736ac 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -9,8 +9,8 @@ from torchdata.datapipes.iter import FileOpener, TarArchiveLoader from torchvision.datasets._optical_flow import _read_flo as read_flo_ref from torchvision.datasets.utils import _decompress -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource -from torchvision.prototype.datasets.utils._internal import read_flo, fromfile +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import fromfile, read_flo @pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning") diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index c76a84f8634..eefb1669901 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -2,7 +2,7 @@ import test_models as TM import torch import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo -from common_utils import set_rng_seed, cpu_and_gpu +from common_utils import cpu_and_gpu, set_rng_seed @pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime)) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 51eba38c7a6..6ccf779129e 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -4,15 +4,15 @@ import torch from common_utils import assert_equal from test_prototype_transforms_functional import ( - make_images, - make_bounding_boxes, make_bounding_box, - 
make_one_hot_labels, + make_bounding_boxes, + make_images, make_label, + make_one_hot_labels, make_segmentation_mask, ) -from torchvision.prototype import transforms, features -from torchvision.transforms.functional import to_pil_image, pil_to_tensor, InterpolationMode +from torchvision.prototype import features, transforms +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image def make_vanilla_tensor_images(*args, **kwargs): diff --git a/test/test_transforms.py b/test/test_transforms.py index e8eba1e3b48..47872cc0b68 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -24,7 +24,7 @@ except ImportError: stats = None -from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal +from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes GRACE_HOPPER = get_file_path_2( diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7dc6dbd95d9..f4ca544deb8 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -6,19 +6,18 @@ import torch import torchvision.transforms._pil_constants as _pil_constants from common_utils import ( - get_tmp_dir, - int_dtypes, - float_dtypes, + _assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, - cpu_and_gpu, assert_equal, + cpu_and_gpu, + float_dtypes, + get_tmp_dir, + int_dtypes, ) from torchvision import transforms as T -from torchvision.transforms import InterpolationMode -from torchvision.transforms import functional as F +from torchvision.transforms import functional as F, InterpolationMode from torchvision.transforms.autoaugment import _apply_op NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC diff --git a/test/test_utils.py b/test/test_utils.py index 7cff53e98a3..dde3ee90dc3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -10,7 +10,7 @@ import torchvision.transforms.functional as F import torchvision.utils as utils from common_utils import assert_equal -from PIL import Image, __version__ as PILLOW_VERSION, ImageColor +from PIL import __version__ as PILLOW_VERSION, Image, ImageColor PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) @@ -45,8 +45,8 @@ def test_normalize_in_make_grid(): # Rounding the result to one decimal for comparison n_digits = 1 - rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) - rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) + rounded_grid_max = torch.round(grid_max * 10**n_digits) / (10**n_digits) + rounded_grid_min = torch.round(grid_min * 10**n_digits) / (10**n_digits) assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1") assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0") diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 32b522cbc42..739f79407b3 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -2,12 +2,7 @@ import warnings import torch -from torchvision import datasets -from torchvision import io -from torchvision import models -from torchvision import ops -from torchvision import transforms -from torchvision import utils +from torchvision import datasets, io, models, ops, transforms, utils from .extension import _HAS_OPS diff --git a/torchvision/_utils.py b/torchvision/_utils.py index 8e8fe1b8a83..b739ef0966e 100644 --- 
a/torchvision/_utils.py +++ b/torchvision/_utils.py @@ -1,5 +1,5 @@ import enum -from typing import Sequence, TypeVar, Type +from typing import Sequence, Type, TypeVar T = TypeVar("T", bound=enum.Enum) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 295fe922478..099d10da35d 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,4 @@ -from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K +from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -11,19 +11,19 @@ from .fakedata import FakeData from .fer2013 import FER2013 from .fgvc_aircraft import FGVCAircraft -from .flickr import Flickr8k, Flickr30k +from .flickr import Flickr30k, Flickr8k from .flowers102 import Flowers102 -from .folder import ImageFolder, DatasetFolder +from .folder import DatasetFolder, ImageFolder from .food101 import Food101 from .gtsrb import GTSRB from .hmdb51 import HMDB51 from .imagenet import ImageNet from .inaturalist import INaturalist -from .kinetics import Kinetics400, Kinetics +from .kinetics import Kinetics, Kinetics400 from .kitti import Kitti -from .lfw import LFWPeople, LFWPairs +from .lfw import LFWPairs, LFWPeople from .lsun import LSUN, LSUNClass -from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .omniglot import Omniglot from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM @@ -40,7 +40,7 @@ from .ucf101 import UCF101 from .usps import USPS from .vision import VisionDataset -from .voc import VOCSegmentation, VOCDetection +from .voc import VOCDetection, VOCSegmentation from .widerface import WIDERFace __all__ = ( diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 8a36c1b8d04..bc26f51dc75 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -9,7 +9,7 @@ from PIL import Image from ..io.image import _read_png_16 -from .utils import verify_str_arg, _read_pfm +from .utils import _read_pfm, verify_str_arg from .vision import VisionDataset @@ -466,7 +466,7 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): flow_and_valid = _read_png_16(file_name).to(torch.float32) flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] - flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive + flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive valid_flow_mask = valid_flow_mask.bool() # For consistency with other datasets, we convert to numpy diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index e95043ce2de..3a9635dfe09 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -1,6 +1,6 @@ import os import os.path -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index e9dd883b92e..dbacece88c9 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,12 +1,12 @@ import csv import os from collections import namedtuple -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union 
import PIL import torch -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive +from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index b07c093e10c..86d65c7c091 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -1,11 +1,11 @@ import json import os from collections import namedtuple -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image -from .utils import extract_archive, verify_str_arg, iterable_to_str +from .utils import extract_archive, iterable_to_str, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 112765a6b5d..94e261e3355 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -1,6 +1,6 @@ import json import pathlib -from typing import Any, Callable, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple from urllib.parse import urlparse from PIL import Image diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 9bb8bda67d1..f53aba16e5f 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -1,5 +1,5 @@ import os.path -from typing import Any, Callable, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index b5c650cb276..9a62520fe2b 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -2,7 +2,7 @@ from typing import Callable, Optional from .folder import ImageFolder -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg class Country211(ImageFolder): diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index deb27312573..2d8314346b9 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -1,10 +1,10 @@ import os import pathlib -from typing import Optional, Callable +from typing import Callable, Optional import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/fer2013.py b/torchvision/datasets/fer2013.py index 60cbfd9bf28..bcd20c1e4a2 100644 --- a/torchvision/datasets/fer2013.py +++ b/torchvision/datasets/fer2013.py @@ -5,7 +5,7 @@ import torch from PIL import Image -from .utils import verify_str_arg, check_integrity +from .utils import check_integrity, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 97a8fb416ba..ad3a6dda0e8 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index d5a7e88083b..40d5e26d242 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,7 +1,6 
@@ import os import os.path -from typing import Any, Callable, cast, Dict, List, Optional, Tuple -from typing import Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index 1bb4d8094d5..aa405eedcf9 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -1,10 +1,10 @@ import json from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index f7341f4aa30..9067418d847 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -1,6 +1,6 @@ import glob import os -from typing import Optional, Callable, Tuple, Dict, Any, List +from typing import Any, Callable, Dict, List, Optional, Tuple from torch import Tensor diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index a272bb86e57..4b86bf2f2b9 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -2,7 +2,7 @@ import shutil import tempfile from contextlib import contextmanager -from typing import Any, Dict, List, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index 7d5fc279820..50b32ef0f4a 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -1,6 +1,6 @@ import os import os.path -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 2ba5e50845e..9352355522d 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -11,7 +11,7 @@ from torch import Tensor from .folder import find_classes, make_dataset -from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity +from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg from .video_utils import VideoClips from .vision import VisionDataset diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index c290e6dc0e8..a936351cdcc 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -7,7 +7,7 @@ from PIL import Image -from .utils import verify_str_arg, iterable_to_str +from .utils import iterable_to_str, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 9f9ec457499..fd742544935 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -12,7 +12,7 @@ import torch from PIL import Image -from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity +from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 5a09d61ccca..41d18c1bdd5 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -3,7 +3,7 @@ 
from PIL import Image -from .utils import download_and_extract_archive, check_integrity, list_dir, list_files +from .utils import check_integrity, download_and_extract_archive, list_dir, list_files from .vision import VisionDataset diff --git a/torchvision/datasets/oxford_iiit_pet.py b/torchvision/datasets/oxford_iiit_pet.py index 733aa78256b..667ee13717d 100644 --- a/torchvision/datasets/oxford_iiit_pet.py +++ b/torchvision/datasets/oxford_iiit_pet.py @@ -1,8 +1,7 @@ import os import os.path import pathlib -from typing import Any, Callable, Optional, Union, Tuple -from typing import Sequence +from typing import Any, Callable, Optional, Sequence, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py index 4f124674961..63faf721a0f 100644 --- a/torchvision/datasets/pcam.py +++ b/torchvision/datasets/pcam.py @@ -3,7 +3,7 @@ from PIL import Image -from .utils import download_file_from_google_drive, _decompress, verify_str_arg +from .utils import _decompress, download_file_from_google_drive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/places365.py b/torchvision/datasets/places365.py index dd11d7331ae..c26b6f03074 100644 --- a/torchvision/datasets/places365.py +++ b/torchvision/datasets/places365.py @@ -4,7 +4,7 @@ from urllib.parse import urljoin from .folder import default_loader -from .utils import verify_str_arg, check_integrity, download_and_extract_archive +from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 02445dddb05..89adf8cf8d8 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image from .folder import make_dataset -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/samplers/__init__.py b/torchvision/datasets/samplers/__init__.py index 861a029a9ec..58b2d2abd93 100644 --- a/torchvision/datasets/samplers/__init__.py +++ b/torchvision/datasets/samplers/__init__.py @@ -1,3 +1,3 @@ -from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler +from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler __all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler") diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py index f4975f8c021..026c3d75d3b 100644 --- a/torchvision/datasets/samplers/clip_sampler.py +++ b/torchvision/datasets/samplers/clip_sampler.py @@ -1,5 +1,5 @@ import math -from typing import Optional, List, Iterator, Sized, Union, cast +from typing import cast, Iterator, List, Optional, Sized, Union import torch import torch.distributed as dist diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index 030643dc794..8399d025b1b 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -5,7 +5,7 @@ import numpy as np from PIL import Image -from .utils import download_url, verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, download_url, verify_str_arg from .vision import VisionDataset diff --git 
a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index cd483a46190..6bfe0b88cba 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -3,7 +3,7 @@ from PIL import Image -from .utils import download_url, check_integrity +from .utils import check_integrity, download_url from .vision import VisionDataset diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index eb9ee247f13..c47703afbde 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from .utils import download_url, check_integrity +from .utils import check_integrity, download_url from .vision import VisionDataset diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py index daca0b0b46a..3e9430ef214 100644 --- a/torchvision/datasets/stanford_cars.py +++ b/torchvision/datasets/stanford_cars.py @@ -1,5 +1,5 @@ import pathlib -from typing import Callable, Optional, Any, Tuple +from typing import Any, Callable, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 1ef50cf0a24..8a906619a9d 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -1,5 +1,5 @@ import os.path -from typing import Any, Callable, Optional, Tuple, cast +from typing import Any, Callable, cast, Optional, Tuple import numpy as np from PIL import Image diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index cc3457fb16f..05cb910dde8 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index 8a2e5839971..facb2d8858e 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from .utils import download_url, check_integrity, verify_str_arg +from .utils import check_integrity, download_url, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index 4ee5f1f3df9..c82b509e535 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Tuple, Optional, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple from torch import Tensor diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b14f25e986b..30506b3fc79 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -15,7 +15,7 @@ import urllib.request import warnings import zipfile -from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar from urllib.parse import urlparse import numpy as np @@ -23,10 +23,7 @@ import torch from torch.utils.model_zoo import tqdm -from .._internally_replaced_utils import ( - _download_file_from_remote_location, - _is_remote_location_available, -) +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available USER_AGENT = "pytorch/vision" diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 3fdd50d19c7..c4890ff4416 100644 --- 
a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -2,15 +2,10 @@ import math import warnings from fractions import Fraction -from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union import torch -from torchvision.io import ( - _probe_video_from_file, - _read_video_from_file, - read_video, - read_video_timestamps, -) +from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps from .utils import tqdm diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 3448d62702c..32888cd5c8c 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -9,7 +9,7 @@ except ImportError: from xml.etree.ElementTree import parse as ET_parse import warnings -from typing import Any, Callable, Dict, Optional, Tuple, List +from typing import Any, Callable, Dict, List, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index a0f1e1fe285..b46c7982d8b 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -1,16 +1,11 @@ import os from os.path import abspath, expanduser -from typing import Any, Callable, List, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from PIL import Image -from .utils import ( - download_file_from_google_drive, - download_and_extract_archive, - extract_archive, - verify_str_arg, -) +from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/extension.py b/torchvision/extension.py index ae1da9c0d04..3bad8351b23 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -23,7 +23,6 @@ def _has_ops(): def _has_ops(): # noqa: F811 return True - except (ImportError, OSError): pass diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 22788cef71e..ba7d4f69f26 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -9,8 +9,6 @@ except ModuleNotFoundError: _HAS_GPU_VIDEO_DECODER = False from ._video_opt import ( - Timebase, - VideoMetaData, _HAS_VIDEO_OPT, _probe_video_from_file, _probe_video_from_memory, @@ -18,25 +16,23 @@ _read_video_from_memory, _read_video_timestamps_from_file, _read_video_timestamps_from_memory, + Timebase, + VideoMetaData, ) from .image import ( - ImageReadMode, decode_image, decode_jpeg, decode_png, encode_jpeg, encode_png, + ImageReadMode, read_file, read_image, write_file, write_jpeg, write_png, ) -from .video import ( - read_video, - read_video_timestamps, - write_video, -) +from .video import read_video, read_video_timestamps, write_video from .video_reader import VideoReader diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 055b195a8f4..b598196d413 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,7 +1,7 @@ import math import warnings from fractions import Fraction -from typing import List, Tuple, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index 881b9d75bd4..c2ffa049d31 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -8,16 +8,13 @@ from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER except 
ModuleNotFoundError: _HAS_GPU_VIDEO_DECODER = False -from ._video_opt import ( - _HAS_VIDEO_OPT, -) +from ._video_opt import _HAS_VIDEO_OPT if _HAS_VIDEO_OPT: def _has_video_opt() -> bool: return True - else: def _has_video_opt() -> bool: diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 00b5ebefe55..7bca0276c34 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -13,9 +13,5 @@ from .vgg import * from .vision_transformer import * from .swin_transformer import * -from . import detection -from . import optical_flow -from . import quantization -from . import segmentation -from . import video +from . import detection, optical_flow, quantization, segmentation, video from ._api import get_weight diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 7c6530d66c4..901bb0015e4 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,7 +3,7 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, Dict, Mapping, cast +from typing import Any, Callable, cast, Dict, Mapping from torchvision._utils import StrEnum diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index c565f611999..5d930e60295 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -2,7 +2,7 @@ import inspect import warnings from collections import OrderedDict -from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union from torch import nn diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 6c461a501c9..5d1401dcb36 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,9 +6,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 435789ca0e2..5b79e5934f4 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -9,9 +9,9 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index e8a66f5771b..8eaac615c86 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -11,9 +11,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 12b3784099f..7d28e96d305 100644 --- 
a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple import torch -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F -from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss +from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss class BalancedPositiveNegativeSampler: diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index fbef524b99c..4941d7ec440 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,8 +6,8 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._api import WeightsEnum, _get_enum_from_fn -from .._utils import IntermediateLayerGetter, handle_legacy_interface +from .._api import _get_enum_from_fn, WeightsEnum +from .._utils import handle_legacy_interface, IntermediateLayerGetter class BackboneWithFPN(nn.Module): diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index fb98ca86b34..de46aadfe4f 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -7,17 +7,17 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor +from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers from .generalized_rcnn import GeneralizedRCNN from .roi_heads import RoIHeads -from .rpn import RPNHead, RegionProposalNetwork +from .rpn import RegionProposalNetwork, RPNHead from .transform import GeneralizedRCNNTransform @@ -250,7 +250,7 @@ def __init__( if box_head is None: resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) + box_head = TwoMLPHead(out_channels * resolution**2, representation_size) if box_predictor is None: representation_size = 1024 diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 9851b7f7c05..efaac721328 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,21 +2,19 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import nn, Tensor -from ...ops import sigmoid_focal_loss, generalized_box_iou_loss -from ...ops import boxes as box_ops -from ...ops import misc as misc_nn_ops +from ...ops import boxes 
as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from . import _utils as det_utils from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index fdcaea5a3eb..b481265077f 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -4,7 +4,7 @@ import warnings from collections import OrderedDict -from typing import Tuple, List, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn, Tensor diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 0052e49409c..f4044a2c1a2 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -6,10 +6,10 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 66dde13adff..422bacd135b 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -6,13 +6,13 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers -from .faster_rcnn import FasterRCNN, FastRCNNConvFCHead, RPNHead, _default_anchorgen +from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead __all__ = [ diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 18e6b432a4f..57c75354389 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -2,23 +2,21 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, 
Optional, Tuple import torch from torch import nn, Tensor -from ...ops import sigmoid_focal_loss -from ...ops import boxes as box_ops -from ...ops import misc as misc_nn_ops +from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from . import _utils as det_utils -from ._utils import overwrite_eps, _box_loss +from ._utils import _box_loss, overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index d2abebfca68..18a6782a06b 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 39f82ca323b..07a8b931150 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -1,10 +1,9 @@ -from typing import List, Optional, Dict, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch import nn, Tensor from torch.nn import functional as F -from torchvision.ops import Conv2dNormActivation -from torchvision.ops import boxes as box_ops +from torchvision.ops import boxes as box_ops, Conv2dNormActivation from . import _utils as det_utils @@ -322,15 +321,12 @@ def compute_loss( labels = torch.cat(labels, dim=0) regression_targets = torch.cat(regression_targets, dim=0) - box_loss = ( - F.smooth_l1_loss( - pred_bbox_deltas[sampled_pos_inds], - regression_targets[sampled_pos_inds], - beta=1 / 9, - reduction="sum", - ) - / (sampled_inds.numel()) - ) + box_loss = F.smooth_l1_loss( + pred_bbox_deltas[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1 / 9, + reduction="sum", + ) / (sampled_inds.numel()) objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds]) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index bcbea25d6d7..1a926116450 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -9,10 +9,10 @@ from ...ops import boxes as box_ops from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..vgg import VGG, VGG16_Weights, vgg16 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..vgg import VGG, vgg16, VGG16_Weights from . 
import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 3be9b6fb9f2..7d695823b39 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -10,10 +10,10 @@ from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once from .. import mobilenet -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 4f653a86acd..dd2d728abf9 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple, Dict, Optional, Any +from typing import Any, Dict, List, Optional, Tuple import torch import torchvision diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index bfd59aee951..417ebabcbe5 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, Optional, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import nn, Tensor @@ -12,9 +12,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 1b380076b2a..d247d9a3e26 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -5,12 +5,11 @@ from collections import OrderedDict from copy import deepcopy from itertools import chain -from typing import Dict, Callable, List, Union, Optional, Tuple, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torchvision -from torch import fx -from torch import nn +from torch import fx, nn from torch.fx.graph_module import _copy_attr diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 5b0a91d4791..895fcd1e4e6 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple from functools import partial -from typing import Optional, Tuple, List, Callable, Any +from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -10,9 +10,9 @@ from ..transforms._presets import 
ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 9207485085f..c1a87954f7c 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple from functools import partial -from typing import Callable, Any, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -9,9 +9,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 8286674d232..27117ae3a83 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -8,9 +8,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 4c4a7d1e293..06fbff2802a 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,17 +1,16 @@ import warnings from functools import partial -from typing import Callable, Any, Optional, List +from typing import Any, Callable, List, Optional import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index dfdd529bfc2..10d2a1c91ac 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,9 +8,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git 
a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index c535300a68c..f0205ef608c 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -8,10 +8,10 @@ from torch.nn import functional as F from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index ba4b21d4112..1698cec7557 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -7,12 +7,12 @@ import torch.nn.functional as F from torch import Tensor from torchvision.models import inception as inception_module -from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights +from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 936e9bcc1b1..61a3cb7eeba 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,16 +1,15 @@ from functools import partial from typing import Any, Optional, Union -from torch import Tensor -from torch import nn -from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weights, MobileNetV2 from ...ops.misc import Conv2dNormActivation from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 94036143138..56341bb280e 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -3,19 +3,19 @@ import torch from torch import nn, Tensor -from torch.ao.quantization import QuantStub, DeQuantStub +from torch.ao.quantization import DeQuantStub, QuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights 
+from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( + _mobilenet_v3_conf, InvertedResidual, InvertedResidualConfig, - MobileNetV3, - _mobilenet_v3_conf, MobileNet_V3_Large_Weights, + MobileNetV3, ) from .utils import _fuse_modules, _replace_relu diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index d51bde50a57..bf3c733887e 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,12 +1,12 @@ from functools import partial -from typing import Any, Type, Union, List, Optional +from typing import Any, List, Optional, Type, Union import torch import torch.nn as nn from torch import Tensor from torchvision.models.resnet import ( - Bottleneck, BasicBlock, + Bottleneck, ResNet, ResNet18_Weights, ResNet50_Weights, @@ -15,9 +15,9 @@ ) from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 781591ae118..028df8be982 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -7,9 +7,9 @@ from torchvision.models import shufflenetv2 from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from ..shufflenetv2 import ( ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index d2958e8686c..d4b4147404c 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -9,9 +9,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 70602705521..667bece5730 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Type, Any, Callable, Union, List, Optional +from typing import Any, Callable, List, Optional, Type, Union import torch import torch.nn as nn @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, 
handle_legacy_interface __all__ = [ diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 44a60a95c54..56560e9dab5 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Optional, Dict +from typing import Dict, Optional from torch import nn, Tensor from torch.nn import functional as F diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index e232235f0ff..0937369a1e7 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -6,11 +6,11 @@ from torch.nn import functional as F from ...transforms._presets import SemanticSegmentation -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights from ._utils import _SimpleSegmentationModel from .fcn import FCNHead diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index b44d0d7547a..2782d675ffe 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -4,10 +4,10 @@ from torch import nn from ...transforms._presets import SemanticSegmentation -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights from ._utils import _SimpleSegmentationModel diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 385960cbde4..339d5feffe6 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -7,10 +7,10 @@ from ...transforms._presets import SemanticSegmentation from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 48695c70193..cc4291c9a86 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,5 +1,5 @@ from functools import partial -from 
typing import Callable, Any, List, Optional +from typing import Any, Callable, List, Optional import torch import torch.nn as nn @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index dbc0f54fb77..8d43d3a0330 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 7bc6b46c674..db5604fb377 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional, Callable, List, Any +from typing import Any, Callable, List, Optional import torch import torch.nn.functional as F @@ -9,7 +9,7 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param @@ -366,7 +366,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2 ** i_stage + dim = embed_dim * 2**i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 937458b48cd..7c141381ee8 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,14 +1,14 @@ from functools import partial -from typing import Union, List, Dict, Any, Optional, cast +from typing import Any, cast, Dict, List, Optional, Union import torch import torch.nn as nn from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 0fd76399b5e..702116f047c 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -7,10 +7,10 @@ import torch.fx import torch.nn as nn -from ...ops import StochasticDepth, MLP +from ...ops import MLP, StochasticDepth from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import 
Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES from .._utils import _ovewrite_named_param diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 6ec8bfc0b3e..ab369c55553 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,14 +1,14 @@ from functools import partial -from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import torch.nn as nn from torch import Tensor from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 57c1479b13d..e9a8c94cc67 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -1,7 +1,7 @@ import math from collections import OrderedDict from functools import partial -from typing import Any, Callable, List, NamedTuple, Optional, Dict +from typing import Any, Callable, Dict, List, NamedTuple, Optional import torch import torch.nn as nn @@ -9,9 +9,9 @@ from ..ops.misc import Conv2dNormActivation, MLP from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 5d56f0bca42..827505b842d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,25 +1,25 @@ from ._register_onnx_ops import _register_custom_op from .boxes import ( - nms, batched_nms, - remove_small_boxes, - clip_boxes_to_image, box_area, box_convert, box_iou, - generalized_box_iou, - distance_box_iou, + clip_boxes_to_image, complete_box_iou, + distance_box_iou, + generalized_box_iou, masks_to_boxes, + nms, + remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d from .diou_loss import distance_box_iou_loss -from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d +from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute +from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 72c95442b78..e42e7e04a70 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -6,7 +6,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._box_convert import _box_cxcywh_to_xyxy, 
_box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh +from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh from ._utils import _upcast @@ -331,7 +331,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso w_gt = boxes2[:, 2] - boxes2[:, 0] h_gt = boxes2[:, 3] - boxes2[:, 1] - v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) return diou - alpha * v diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index 1f271fb0a1d..a71baf28e70 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -58,7 +58,7 @@ def complete_box_iou_loss( h_pred = y2 - y1 w_gt = x2g - x1g h_gt = y2g - y1g - v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index a798677f60f..e65496ea29a 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -37,7 +37,7 @@ def drop_block2d( N, C, H, W = input.size() block_size = min(block_size, W, H) # compute the gamma of Bernoulli distribution - gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) + gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1))) noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) noise.bernoulli_(gamma) @@ -83,7 +83,7 @@ def drop_block3d( N, C, D, H, W = input.size() block_size = min(block_size, D, H, W) # compute the gamma of Bernoulli distribution - gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) + gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) noise = torch.empty( (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device ) diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 9062405a997..ffec3505ec0 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Tuple, List, Dict, Callable, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch.nn.functional as F from torch import nn, Tensor diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index a7210f5739b..0c555ec4fe9 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast_non_float, _loss_inter_union +from ._utils import _loss_inter_union, _upcast_non_float def generalized_box_iou_loss( diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 422119ceaec..d4bda7decc5 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional, Union, Tuple, Sequence +from typing import Callable, List, Optional, Sequence, Tuple, Union import torch from torch import 
Tensor diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index f881201a2d2..cfcb9e94056 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, List, Dict, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.fx diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index 7153e49ac05..0228a2a5554 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -4,7 +4,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_align( diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index a27c36ee76c..1a3eed35915 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -4,7 +4,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_pool( diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 131c1b81d0f..afe9e42af16 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -7,7 +7,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_align( diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 37cbf7febee..50dc2f64421 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -7,7 +7,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_pool( diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index e1be6c81f59..bef5ecc411d 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,5 +1 @@ -from . import datasets -from . import features -from . import models -from . import transforms -from . import utils +from . 
import datasets, features, models, transforms, utils diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 407dc23f64b..f6f06c60a21 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from torchvision.prototype.datasets import home from torchvision.prototype.datasets.utils import Dataset diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 4acc1d53b4d..d84e9af9fc4 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -11,7 +11,7 @@ from .food101 import Food101 from .gtsrb import GTSRB from .imagenet import ImageNet -from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM from .sbd import SBD diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index fe3dc2000e6..a00bf2e2cc9 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,23 +1,18 @@ import pathlib import re -from typing import Any, Dict, List, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - IterKeyZipper, -) +from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, read_categories_file, + read_mat, ) -from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 46ccf8de6f7..e42657e826e 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,27 +1,17 @@ import csv import pathlib -from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union - -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - Zipper, - IterKeyZipper, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - GDriveResource, - OnlineResource, -) +from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, getitem, - path_accessor, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, ) -from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import 
register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 514938d6e5f..26196ded638 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -2,22 +2,18 @@ import io import pathlib import pickle -from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Filter, - Mapper, -) +from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( + hint_sharding, hint_shuffling, path_comparator, - hint_sharding, read_categories_file, ) -from torchvision.prototype.features import Label, Image +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 3a139787c6f..4ddacdfb982 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -1,17 +1,17 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, + getitem, hint_sharding, hint_shuffling, - path_comparator, + INFINITE_BUFFER_SIZE, path_accessor, - getitem, + path_comparator, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index ff3b5f37c96..16a16998bf7 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,35 +1,30 @@ import pathlib import re -from collections import OrderedDict -from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union +from collections import defaultdict, OrderedDict +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union import torch from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, Demultiplexer, + Filter, Grouper, + IterDataPipe, IterKeyZipper, JsonParser, + Mapper, UnBatcher, ) -from torchvision.prototype.datasets.utils import ( - HttpResource, - OnlineResource, - Dataset, -) +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - MappingIterator, - INFINITE_BUFFER_SIZE, getitem, - read_categories_file, - path_accessor, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + MappingIterator, + path_accessor, + read_categories_file, ) -from torchvision.prototype.features import BoundingBox, Label, _Feature, 
EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info @@ -151,7 +146,7 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, ) _META_FILE_PATTERN = re.compile( - fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" + rf"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" ) def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 012ecae19e2..f9821ea4eb6 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -1,12 +1,12 @@ import pathlib from typing import Any, Dict, List, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter +from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_comparator, hint_sharding, hint_shuffling, + path_comparator, read_categories_file, ) from torchvision.prototype.features import EncodedImage, Label diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 0e5a80de825..bb3f712c59d 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -1,30 +1,30 @@ import csv import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union +from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, + CSVDictParser, + CSVParser, + Demultiplexer, Filter, + IterDataPipe, IterKeyZipper, - Demultiplexer, LineReader, - CSVParser, - CSVDictParser, + Mapper, ) from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, + getitem, hint_sharding, hint_shuffling, - getitem, + INFINITE_BUFFER_SIZE, + path_accessor, path_comparator, read_categories_file, - path_accessor, + read_mat, ) -from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index b082ada19ce..e7ff1e79559 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,22 +1,18 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, + getitem, 
hint_sharding, + hint_shuffling, + INFINITE_BUFFER_SIZE, path_comparator, - getitem, read_categories_file, - hint_shuffling, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index c1a914c6f63..b2693aa96c0 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -2,17 +2,10 @@ from typing import Any, Dict, List, Union import torch -from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - KaggleDownloadResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, -) -from torchvision.prototype.features import Label, Image +from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper +from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 5100e5d8c74..3657116ae7a 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -1,24 +1,17 @@ from pathlib import Path -from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union - -from torchdata.datapipes.iter import ( - IterDataPipe, - Filter, - Mapper, - LineReader, - Demultiplexer, - IterKeyZipper, -) +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union + +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - hint_shuffling, - hint_sharding, - path_comparator, getitem, + hint_sharding, + hint_shuffling, INFINITE_BUFFER_SIZE, + path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 01f754208e2..8dc0a8240c8 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,19 +1,15 @@ import pathlib from typing import Any, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - HttpResource, -) +from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_comparator, hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE, + path_comparator, ) -from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from 
torchvision.prototype.features import BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 1307757cef6..062e240a8b8 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -2,33 +2,29 @@ import functools import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union +from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union from torchdata.datapipes.iter import ( + Demultiplexer, + Enumerator, + Filter, IterDataPipe, - LineReader, IterKeyZipper, + LineReader, Mapper, - Filter, - Demultiplexer, TarArchiveLoader, - Enumerator, -) -from torchvision.prototype.datasets.utils import ( - OnlineResource, - ManualDownloadResource, - Dataset, ) +from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, getitem, - read_mat, hint_sharding, hint_shuffling, - read_categories_file, + INFINITE_BUFFER_SIZE, path_accessor, + read_categories_file, + read_mat, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index e5537a1ef66..7a459b2d0ea 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -3,12 +3,12 @@ import operator import pathlib import string -from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union import torch -from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor +from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index f7da02a4765..499dbd837ed 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,23 +1,19 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal 
import ( - INFINITE_BUFFER_SIZE, + getitem, hint_sharding, hint_shuffling, - getitem, + INFINITE_BUFFER_SIZE, path_accessor, - read_categories_file, path_comparator, + read_categories_file, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 7cd31469139..162f22f1abd 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -1,19 +1,12 @@ import io import pathlib from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Iterator, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - GDriveResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, -) +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 0c806fe098c..c7a79c4188e 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,26 +1,19 @@ import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Demultiplexer, - Filter, - IterKeyZipper, - LineReader, -) +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, getitem, - path_accessor, - path_comparator, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, + path_comparator, read_categories_file, + read_mat, ) from torchvision.prototype.features import _Feature, EncodedImage diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index 5051bde4047..8107f6565e4 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -2,16 +2,8 @@ from typing import Any, Dict, List, Tuple, Union import torch -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - CSVParser, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, OneHotLabel diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 
465d753c2e5..011204f2bfb 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union +from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource @@ -7,8 +7,8 @@ hint_sharding, hint_shuffling, path_comparator, - read_mat, read_categories_file, + read_mat, ) from torchvision.prototype.features import BoundingBox, EncodedImage, Label diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 175aa6c0a51..6dd55a77c99 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,23 +1,11 @@ import pathlib -from typing import Any, Dict, List, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - UnBatcher, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) -from torchvision.prototype.datasets.utils._internal import ( - read_mat, - hint_sharding, - hint_shuffling, -) -from torchvision.prototype.features import Label, Image +from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index e732f3b788a..e5ca58f8428 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Union import torch -from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource +from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d875df521f2..2f13ce10d6f 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,29 +1,22 @@ import enum import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union from xml.etree import ElementTree -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - Demultiplexer, - IterKeyZipper, - LineReader, -) +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset +from 
torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_accessor, getitem, - INFINITE_BUFFER_SIZE, - path_comparator, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, + path_comparator, read_categories_file, ) -from torchvision.prototype.features import BoundingBox, Label, EncodedImage +from torchvision.prototype.features import BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index c3a38becb6c..b2ec23c5e3d 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -2,11 +2,11 @@ import os import os.path import pathlib -from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any +from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener +from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label, EncodedImage, EncodedData +from torchvision.prototype.features import EncodedData, EncodedImage, Label __all__ = ["from_data_folder", "from_image_folder"] diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 94c5907b47d..41ccbf48951 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,3 +1,3 @@ from . import _internal # usort: skip from ._dataset import Dataset -from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource +from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 528d0a0f25f..e7486c854ac 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,7 +1,7 @@ import abc import importlib import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator +from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union from torch.utils.data import IterDataPipe from torchvision.datasets.utils import verify_str_arg diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 007e91eb657..6768469be67 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -2,21 +2,7 @@ import functools import pathlib import pickle -from typing import BinaryIO -from typing import ( - Sequence, - Callable, - Union, - Any, - Tuple, - TypeVar, - List, - Iterator, - Dict, - IO, - Sized, -) -from typing import cast +from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union import torch import torch.distributed as dist diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 3c9b95cb498..dc01c72de28 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ 
b/torchvision/prototype/datasets/utils/_resource.py @@ -2,26 +2,26 @@ import hashlib import itertools import pathlib -from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set +from typing import Any, Callable, IO, NoReturn, Optional, Sequence, Set, Tuple, Union from urllib.parse import urlparse from torchdata.datapipes.iter import ( - IterableWrapper, FileLister, FileOpener, + IterableWrapper, IterDataPipe, - ZipArchiveLoader, - TarArchiveLoader, RarArchiveLoader, + TarArchiveLoader, + ZipArchiveLoader, ) from torchvision.datasets.utils import ( - download_url, - _detect_file_type, - extract_archive, _decompress, - download_file_from_google_drive, - _get_redirect_url, + _detect_file_type, _get_google_drive_file_id, + _get_redirect_url, + download_file_from_google_drive, + download_url, + extract_archive, ) from typing_extensions import Literal diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index f1957df5f1d..59b88d2931f 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Tuple, Union, Optional, Sequence +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torchvision._utils import StrEnum diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 612714c4c3a..ccab0b1b8a8 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -2,7 +2,7 @@ import os import sys -from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any +from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union import PIL.Image import torch diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index b2dedea86d3..8ccbfda57e0 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,4 +1,4 @@ -from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, List, Tuple, Sequence, Mapping +from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union import torch from torch._C import _TensorBase, DisableTorchFunction diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 18cf7c1964d..1a55e5c5acb 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,13 +1,12 @@ from __future__ import annotations import warnings -from typing import Any, List, Optional, Union, Sequence, Tuple, cast +from typing import Any, cast, List, Optional, Sequence, Tuple, Union import torch from torchvision._utils import StrEnum -from torchvision.transforms.functional import to_pil_image, InterpolationMode -from torchvision.utils import draw_bounding_boxes -from torchvision.utils import make_grid +from torchvision.transforms.functional import InterpolationMode, to_pil_image +from torchvision.utils import draw_bounding_boxes, make_grid from ._bounding_box import BoundingBox from ._feature import _Feature diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index e3433b7bb08..c61419a61b6 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, cast, 
Union +from typing import Any, cast, Optional, Sequence, Union import torch from torchvision.prototype.utils._internal import apply_recursively diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index 5f7ea80430b..406e06aef11 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Union, Sequence +from typing import List, Optional, Sequence, Union import torch from torchvision.transforms import InterpolationMode diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index 418e2629c48..fa636f8ef00 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Callable, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -6,8 +6,8 @@ import torchvision.models.optical_flow.raft as raft from torch import Tensor from torchvision.models._api import WeightsEnum -from torchvision.models.optical_flow._utils import make_coords_grid, grid_sample, upsample_flow -from torchvision.models.optical_flow.raft import ResidualBlock, MotionEncoder, FlowHead +from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow +from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock from torchvision.ops import Conv2dNormActivation from torchvision.utils import _log_api_usage_once diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 483f811ef7d..15abd4f77f2 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -2,38 +2,38 @@ from ._transform import Transform # usort: skip -from ._augment import RandomErasing, RandomMixup, RandomCutmix -from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix +from ._augment import RandomCutmix, RandomErasing, RandomMixup +from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._color import ( ColorJitter, - RandomPhotometricDistort, + RandomAdjustSharpness, + RandomAutocontrast, RandomEqualize, RandomInvert, + RandomPhotometricDistort, RandomPosterize, RandomSolarize, - RandomAutocontrast, - RandomAdjustSharpness, ) from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( - Resize, + BatchMultiCrop, CenterCrop, - RandomResizedCrop, - RandomCrop, + ElasticTransform, FiveCrop, - TenCrop, - BatchMultiCrop, - RandomHorizontalFlip, - RandomVerticalFlip, Pad, - RandomZoomOut, - RandomRotation, RandomAffine, + RandomCrop, + RandomHorizontalFlip, RandomPerspective, - ElasticTransform, + RandomResizedCrop, + RandomRotation, + RandomVerticalFlip, + RandomZoomOut, + Resize, + TenCrop, ) -from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace -from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda +from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype +from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip diff --git 
a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 8ed81eef8f2..d1c3db816ad 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -6,10 +6,10 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, has_all +from ._utils import get_image_dimensions, has_all, has_any, query_image class RandomErasing(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 03aa96e08fb..f4f1a3547b1 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,13 +1,13 @@ import math -from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Sequence, TypeVar, Union, Type +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.utils._internal import query_recursively from torchvision.transforms.autoaugment import AutoAugmentPolicy -from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from ._utils import get_image_dimensions diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 7fd198161c9..bc29fe5b677 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,14 +1,14 @@ import collections.abc -from typing import Any, Dict, Union, Tuple, Optional, Sequence, TypeVar +from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms import functional as _F from ._transform import _RandomApplyTransform -from ._utils import is_simple_tensor, get_image_dimensions, query_image +from ._utils import get_image_dimensions, is_simple_tensor, query_image T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index e9c72e2e020..fd1f58f3351 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List, Dict +from typing import Any, Dict, List, Optional import torch from torchvision.prototype.transforms import Transform diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 8a1b94060c4..b1618b0eef5 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -2,18 +2,18 @@ import math import numbers import warnings -from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast 
+from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F -from torchvision.transforms.functional import pil_to_tensor, InterpolationMode -from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input +from torchvision.prototype.transforms import functional as F, Transform +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor +from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor +from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image class RandomHorizontalFlip(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 6791bbbc69c..fcf0e0db883 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,9 +1,9 @@ -from typing import Union, Any, Dict, Optional +from typing import Any, Dict, Optional, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms.functional import convert_image_dtype from ._utils import is_simple_tensor diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index b8e9101f2a0..14c33db3ecb 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,8 +1,8 @@ import functools -from typing import Any, List, Type, Callable, Dict, Sequence, Union +from typing import Any, Callable, Dict, List, Sequence, Type, Union import torch -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms.transforms import _setup_size diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 09c071a27e0..9a698aa5e23 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -3,7 +3,7 @@ import numpy as np import PIL.Image from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from ._utils import is_simple_tensor diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index c41ef294975..c9fe79e41fe 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, Union, Type, Iterator +from typing import Any, Iterator, Optional, Tuple, Type, Union import PIL.Image import torch @@ -6,7 +6,7 @@ from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively -from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil +from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor def query_image(sample: Any) -> Union[PIL.Image.Image, 
torch.Tensor, features.Image]: diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 638049f96fb..1aef37600d6 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -5,116 +5,109 @@ convert_image_color_space_pil, ) # usort: skip -from ._augment import ( - erase_image_tensor, -) +from ._augment import erase_image_tensor from ._color import ( adjust_brightness, - adjust_brightness_image_tensor, adjust_brightness_image_pil, + adjust_brightness_image_tensor, adjust_contrast, - adjust_contrast_image_tensor, adjust_contrast_image_pil, + adjust_contrast_image_tensor, + adjust_gamma, + adjust_gamma_image_pil, + adjust_gamma_image_tensor, + adjust_hue, + adjust_hue_image_pil, + adjust_hue_image_tensor, adjust_saturation, - adjust_saturation_image_tensor, adjust_saturation_image_pil, + adjust_saturation_image_tensor, adjust_sharpness, - adjust_sharpness_image_tensor, adjust_sharpness_image_pil, - adjust_hue, - adjust_hue_image_tensor, - adjust_hue_image_pil, - adjust_gamma, - adjust_gamma_image_tensor, - adjust_gamma_image_pil, - posterize, - posterize_image_tensor, - posterize_image_pil, - solarize, - solarize_image_tensor, - solarize_image_pil, + adjust_sharpness_image_tensor, autocontrast, - autocontrast_image_tensor, autocontrast_image_pil, + autocontrast_image_tensor, equalize, - equalize_image_tensor, equalize_image_pil, + equalize_image_tensor, invert, - invert_image_tensor, invert_image_pil, + invert_image_tensor, + posterize, + posterize_image_pil, + posterize_image_tensor, + solarize, + solarize_image_pil, + solarize_image_tensor, ) from ._geometry import ( + affine, + affine_bounding_box, + affine_image_pil, + affine_image_tensor, + affine_segmentation_mask, + center_crop, + center_crop_bounding_box, + center_crop_image_pil, + center_crop_image_tensor, + center_crop_segmentation_mask, + crop, + crop_bounding_box, + crop_image_pil, + crop_image_tensor, + crop_segmentation_mask, + elastic, + elastic_bounding_box, + elastic_image_pil, + elastic_image_tensor, + elastic_segmentation_mask, + elastic_transform, + five_crop_image_pil, + five_crop_image_tensor, horizontal_flip, horizontal_flip_bounding_box, - horizontal_flip_image_tensor, horizontal_flip_image_pil, + horizontal_flip_image_tensor, horizontal_flip_segmentation_mask, + pad, + pad_bounding_box, + pad_image_pil, + pad_image_tensor, + pad_segmentation_mask, + perspective, + perspective_bounding_box, + perspective_image_pil, + perspective_image_tensor, + perspective_segmentation_mask, resize, resize_bounding_box, - resize_image_tensor, resize_image_pil, + resize_image_tensor, resize_segmentation_mask, - center_crop, - center_crop_bounding_box, - center_crop_segmentation_mask, - center_crop_image_tensor, - center_crop_image_pil, resized_crop, resized_crop_bounding_box, - resized_crop_image_tensor, resized_crop_image_pil, + resized_crop_image_tensor, resized_crop_segmentation_mask, - affine, - affine_bounding_box, - affine_image_tensor, - affine_image_pil, - affine_segmentation_mask, rotate, rotate_bounding_box, - rotate_image_tensor, rotate_image_pil, + rotate_image_tensor, rotate_segmentation_mask, - pad, - pad_bounding_box, - pad_image_tensor, - pad_image_pil, - pad_segmentation_mask, - crop, - crop_bounding_box, - crop_image_tensor, - crop_image_pil, - crop_segmentation_mask, - perspective, - perspective_bounding_box, - perspective_image_tensor, - perspective_image_pil, - 
perspective_segmentation_mask, - elastic, - elastic_transform, - elastic_bounding_box, - elastic_image_tensor, - elastic_image_pil, - elastic_segmentation_mask, + ten_crop_image_pil, + ten_crop_image_tensor, vertical_flip, - vertical_flip_image_tensor, - vertical_flip_image_pil, vertical_flip_bounding_box, + vertical_flip_image_pil, + vertical_flip_image_tensor, vertical_flip_segmentation_mask, - five_crop_image_tensor, - five_crop_image_pil, - ten_crop_image_tensor, - ten_crop_image_pil, -) -from ._misc import ( - normalize_image_tensor, - gaussian_blur, - gaussian_blur_image_tensor, - gaussian_blur_image_pil, ) +from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize_image_tensor from ._type_conversion import ( decode_image_with_pil, decode_video_with_av, label_to_one_hot, - to_image_tensor, to_image_pil, + to_image_tensor, ) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index d5c5d305722..554fb98ae52 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -3,7 +3,7 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT # shortcut type diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index e48701bb436..61506393a4e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,22 +1,22 @@ import numbers import warnings -from typing import Tuple, List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( - pil_modes_mapping, - _get_inverse_affine_matrix, - InterpolationMode, _compute_output_size, + _get_inverse_affine_matrix, _get_perspective_coeffs, + InterpolationMode, + pil_modes_mapping, pil_to_tensor, to_pil_image, ) -from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil +from ._meta import convert_bounding_box_format, get_dimensions_image_pil, get_dimensions_image_tensor # shortcut type diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 2386f47b226..db7918558bc 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,9 +1,9 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple import PIL.Image import torch from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT get_dimensions_image_tensor = _FT.get_dimensions get_dimensions_image_pil = _FP.get_dimensions diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index e51cac1745e..d93194e2eab 100644 --- 
a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Union +from typing import List, Optional, Union import PIL.Image import torch diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 37f8f9b70a3..0619852900f 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Dict, Any, Tuple, Union +from typing import Any, Dict, Tuple, Union import numpy as np import PIL.Image diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 233128880e3..fb5c3b83de6 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -3,18 +3,7 @@ import io import mmap import platform -from typing import ( - Any, - BinaryIO, - Callable, - Collection, - Iterator, - Sequence, - Tuple, - TypeVar, - Union, - Optional, -) +from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union import numpy as np import torch diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index e49912e0f00..33b94d01c9d 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple import torch -from torch import Tensor, nn +from torch import nn, Tensor from . import functional as F, InterpolationMode diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py index 69512af6eb1..1ed6de7612d 100644 --- a/torchvision/transforms/_transforms_video.py +++ b/torchvision/transforms/_transforms_video.py @@ -4,10 +4,7 @@ import random import warnings -from torchvision.transforms import ( - RandomCrop, - RandomResizedCrop, -) +from torchvision.transforms import RandomCrop, RandomResizedCrop from . import _functional_video as F diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 357e5bf250e..9dbbe91e741 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -1,6 +1,6 @@ import math from enum import Enum -from typing import List, Tuple, Optional, Dict +from typing import Dict, List, Optional, Tuple import torch from torch import Tensor diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 442e8d4288d..8e94733651b 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import torch @@ -15,8 +15,7 @@ accimage = None from ..utils import _log_api_usage_once -from . import functional_pil as F_pil -from . import functional_tensor as F_t +from . 
import functional_pil as F_pil, functional_tensor as F_t class InterpolationMode(Enum): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index c35edfb74b0..8f37005298b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,9 +1,9 @@ import warnings -from typing import Optional, Tuple, List, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor -from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad +from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad def _is_tensor_a_torch_image(x: Tensor) -> bool: @@ -247,7 +247,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: if not torch.is_floating_point(img): result = convert_image_dtype(result, torch.float32) - result = (gain * result ** gamma).clamp(0, 1) + result = (gain * result**gamma).clamp(0, 1) result = convert_image_dtype(result, dtype) return result diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index cf119759982..29b24dc4414 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,7 +3,7 @@ import random import warnings from collections.abc import Sequence -from typing import Tuple, List, Optional +from typing import List, Optional, Tuple import torch from torch import Tensor @@ -15,7 +15,7 @@ from ..utils import _log_api_usage_once from . import functional as F -from .functional import InterpolationMode, _interpolation_modes_from_int +from .functional import _interpolation_modes_from_int, InterpolationMode __all__ = [ "Compose", diff --git a/torchvision/utils.py b/torchvision/utils.py index abb8e7f0e45..3809a13c049 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -449,7 +449,7 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor: if flow.ndim != 4 or flow.shape[1] != 2: raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") - max_norm = torch.sum(flow ** 2, dim=1).sqrt().max() + max_norm = torch.sum(flow**2, dim=1).sqrt().max() epsilon = torch.finfo((flow).dtype).eps normalized_flow = flow / (max_norm + epsilon) img = _normalized_flow_to_image(normalized_flow) @@ -476,7 +476,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) colorwheel = _make_colorwheel().to(device) # shape [55x3] num_cols = colorwheel.shape[0] - norm = torch.sum(normalized_flow ** 2, dim=1).sqrt() + norm = torch.sum(normalized_flow**2, dim=1).sqrt() a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi fk = (a + 1) / 2 * (num_cols - 1) k0 = torch.floor(fk).to(torch.long) @@ -542,7 +542,7 @@ def _make_colorwheel() -> torch.Tensor: def _generate_color_palette(num_objects: int): - palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) return [tuple((i * palette) % 255) for i in range(num_objects)] From 14d221d70ed0b182eeeb3fa8df9fda7389175a80 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 25 Jul 2022 14:08:14 +0200 Subject: [PATCH 09/13] Fixed bug in affine get_params test --- test/test_prototype_transforms.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 6ccf779129e..27c2e1d581c 100644 --- 
a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -525,8 +525,10 @@ def test__get_params(self, degrees, translate, scale, shear, mocker): assert degrees[0] <= params["angle"] <= degrees[1] if translate is not None: - assert -translate[0] * w <= params["translations"][0] <= translate[0] * w - assert -translate[1] * h <= params["translations"][1] <= translate[1] * h + w_max = int(round(translate[0] * w)) + h_max = int(round(translate[1] * h)) + assert -w_max <= params["translations"][0] <= w_max + assert -h_max <= params["translations"][1] <= h_max else: assert params["translations"] == (0, 0) From 2ee8dcac5eedc07a2184fc1c38f8fbd7957c9c96 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 25 Jul 2022 15:50:49 +0200 Subject: [PATCH 10/13] Implemented RandomErase on PIL input as fallback to tensors (#6309) Added tests --- test/test_prototype_transforms.py | 98 ++++++++++++++++++++ torchvision/prototype/transforms/_augment.py | 6 +- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 27c2e1d581c..b21a3c62878 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,5 +1,7 @@ import itertools +import PIL.Image + import pytest import torch from common_utils import assert_equal @@ -879,3 +881,99 @@ def test__transform(self, alpha, sigma, mocker): _ = transform(inpt) params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) + + +class TestRandomErasing: + def test_assertions(self, mocker): + with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"): + transforms.RandomErasing(value={}) + + with pytest.raises(ValueError, match="If value is str, it should be 'random'"): + transforms.RandomErasing(value="abc") + + with pytest.raises(TypeError, match="Scale should be a sequence"): + transforms.RandomErasing(scale=123) + + with pytest.raises(TypeError, match="Ratio should be a sequence"): + transforms.RandomErasing(ratio=123) + + with pytest.raises(ValueError, match="Scale should be between 0 and 1"): + transforms.RandomErasing(scale=[-1, 2]) + + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + transform = transforms.RandomErasing(value=[1, 2, 3, 4]) + + with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): + transform._get_params(image) + + @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) + def test__get_params(self, value, mocker): + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + transform = transforms.RandomErasing(value=value) + params = transform._get_params(image) + + v = params["v"] + h, w = params["h"], params["w"] + i, j = params["i"], params["j"] + assert isinstance(v, torch.Tensor) + if value == "random": + assert v.shape == (image.num_channels, h, w) + elif isinstance(value, (int, float)): + assert v.shape == (1, 1, 1) + elif isinstance(value, (list, tuple)): + assert v.shape == (image.num_channels, 1, 1) + + assert 0 <= i <= image.image_size[0] - h + assert 0 <= j <= image.image_size[1] - w + + @pytest.mark.parametrize("p", [0.0, 1.0]) + @pytest.mark.parametrize( + "inpt_type", + [ + (torch.Tensor, {"shape": (3, 24, 32)}), + (PIL.Image.Image, {"size": (24, 32), "mode": "RGB"}), + ], + ) + def test__transform(self, p, inpt_type, mocker): + value = 1.0 + transform = 
transforms.RandomErasing(p=p, value=value)
+
+        inpt = mocker.MagicMock(spec=inpt_type[0], **inpt_type[1])
+        erase_image_tensor_inpt = inpt
+        fn = mocker.patch(
+            "torchvision.prototype.transforms.functional.erase_image_tensor",
+            return_value=mocker.MagicMock(spec=torch.Tensor),
+        )
+        if inpt_type[0] == PIL.Image.Image:
+            erase_image_tensor_inpt = mocker.MagicMock(spec=torch.Tensor)
+
+            # vfdev-5: I do not know how to patch pil_to_tensor if it is already imported
+            # TODO: patch pil_to_tensor and run below checks for PIL.Image.Image inputs
+            if p > 0.0:
+                return
+
+            mocker.patch(
+                "torchvision.transforms.functional.pil_to_tensor",
+                return_value=erase_image_tensor_inpt,
+            )
+            mocker.patch(
+                "torchvision.transforms.functional.to_pil_image",
+                return_value=mocker.MagicMock(spec=PIL.Image.Image),
+            )
+
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock()
+        output = transform(inpt)
+        print(inpt_type)
+        assert isinstance(output, inpt_type[0])
+        params = transform._get_params(inpt)
+        if p > 0.0:
+            fn.assert_called_once_with(erase_image_tensor_inpt, **params)
+        else:
+            assert fn.call_count == 0
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index d1c3db816ad..2c71a5faf64 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,6 +7,7 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
 from ._transform import _RandomApplyTransform
 from ._utils import get_image_dimensions, has_all, has_any, query_image
@@ -92,8 +93,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 return features.Image.new_like(inpt, output)
             return output
         elif isinstance(inpt, PIL.Image.Image):
-            # TODO: We should implement a fallback to tensor, like gaussian_blur etc
-            raise RuntimeError("Not implemented")
+            t_img = pil_to_tensor(inpt)
+            output = F.erase_image_tensor(t_img, **params)
+            return to_pil_image(output, mode=inpt.mode)
         else:
             return inpt
 

From 23112f813455409636a0ef5927601bb0ba2ef5c4 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Wed, 27 Jul 2022 11:15:00 +0200
Subject: [PATCH 11/13] Added image_size computation for BoundingBox.rotate if expand (#6319)

* Added image_size computation for BoundingBox.rotate if expand

* Added tests
---
 test/test_prototype_transforms.py               | 14 ++++++++++++++
 test/test_prototype_transforms_functional.py    |  5 +++--
 torchvision/prototype/features/_bounding_box.py | 16 +++++++++++++---
 torchvision/prototype/features/_image.py        |  2 +-
 torchvision/transforms/functional_tensor.py     |  2 +-
 5 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b21a3c62878..33dd94925b6 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -467,6 +467,20 @@ def test__transform(self, degrees, expand, fill, center, mocker):
 
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
+    @pytest.mark.parametrize("angle", [34, -87])
+    @pytest.mark.parametrize("expand", [False, True])
+    def test_boundingbox_image_size(self, angle, expand):
+        # Specific test for BoundingBox.rotate
+        bbox = features.BoundingBox(
+            torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
+        )
+        img =
features.Image(torch.rand(1, 3, 32, 32)) + + out_img = img.rotate(angle, expand=expand) + out_bbox = bbox.rotate(angle, expand=expand) + + assert out_img.image_size == out_bbox.image_size + class TestRandomAffine: def test_assertions(self): diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index fb5f10459fe..d3353a0932d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -693,13 +693,11 @@ def test_scriptable(kernel): "InterpolationMode", "decode_video_with_av", "crop", - "rotate", "perspective", "elastic_transform", "elastic", } # We skip 'crop' due to missing 'height' and 'width' - # We skip 'rotate' due to non implemented yet expand=True case for bboxes # We skip 'perspective' as it requires different input args than perspective_image_tensor etc # Skip 'elastic', TODO: inspect why test is failing ], @@ -999,6 +997,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): out_bbox[2] -= tr_x out_bbox[3] -= tr_y + # image_size should be updated, but it is OK here to skip its computation + # as we do not compute it in F.rotate_bounding_box + out_bbox = features.BoundingBox( out_bbox, format=features.BoundingBoxFormat.XYXY, diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 59b88d2931f..54e1315c9ab 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -5,6 +5,8 @@ import torch from torchvision._utils import StrEnum from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import _get_inverse_affine_matrix +from torchvision.transforms.functional_tensor import _compute_output_size from ._feature import _Feature @@ -168,10 +170,18 @@ def rotate( output = _F.rotate_bounding_box( self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center ) - # TODO: update output image size if expand is True + image_size = self.image_size if expand: - raise RuntimeError("Not yet implemented") - return BoundingBox.new_like(self, output, dtype=output.dtype) + # The way we recompute image_size is not optimal due to redundant computations of + # - rotation matrix (_get_inverse_affine_matrix) + # - points dot matrix (_compute_output_size) + # Alternatively, we could return new image size by _F.rotate_bounding_box + height, width = image_size + rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + new_width, new_height = _compute_output_size(rotation_matrix, width, height) + image_size = (new_height, new_width) + + return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size) def affine( self, diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 1a55e5c5acb..303486f98ba 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -74,7 +74,7 @@ def new_like( @property def image_size(self) -> Tuple[int, int]: - return cast(Tuple[int, int], self.shape[-2:]) + return cast(Tuple[int, int], tuple(self.shape[-2:])) @property def num_channels(self) -> int: diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 8f37005298b..df5396a063c 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -634,7 +634,7 @@ def _compute_output_size(matrix: 
List[float], w: int, h: int) -> Tuple[int, int] cmax = torch.ceil((max_vals / tol).trunc_() * tol) cmin = torch.floor((min_vals / tol).trunc_() * tol) size = cmax - cmin - return int(size[0]), int(size[1]) + return int(size[0]), int(size[1]) # w, h def rotate( From 2586de6cefbe44e9748c1b381278f95fd7194f96 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 27 Jul 2022 11:15:23 +0200 Subject: [PATCH 12/13] Added erase_image_pil and eager/jit erase_image_tensor test (#6320) --- test/test_prototype_transforms_functional.py | 7 +++++++ torchvision/prototype/transforms/_augment.py | 5 +---- .../prototype/transforms/functional/__init__.py | 2 +- .../prototype/transforms/functional/_augment.py | 17 ++++++++++------- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index d3353a0932d..5f105b3f6e2 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -654,6 +654,13 @@ def adjust_sharpness_image_tensor(): yield SampleInput(image, sharpness_factor=sharpness_factor) +@register_kernel_info_from_sample_inputs_fn +def erase_image_tensor(): + for image in make_images(): + c = image.shape[-3] + yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7)) + + @pytest.mark.parametrize( "kernel", [ diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 2c71a5faf64..12e2cd3cc6d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,6 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from torchvision.transforms.functional import pil_to_tensor, to_pil_image from ._transform import _RandomApplyTransform from ._utils import get_image_dimensions, has_all, has_any, query_image @@ -93,9 +92,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return features.Image.new_like(inpt, output) return output elif isinstance(inpt, PIL.Image.Image): - t_img = pil_to_tensor(inpt) - output = F.erase_image_tensor(t_img, **params) - return to_pil_image(output, mode=inpt.mode) + return F.erase_image_pil(inpt, **params) else: return inpt diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 1aef37600d6..82e3096821a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -5,7 +5,7 @@ convert_image_color_space_pil, ) # usort: skip -from ._augment import erase_image_tensor +from ._augment import erase_image_pil, erase_image_tensor from ._color import ( adjust_brightness, adjust_brightness_image_pil, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 3920d1b3065..84b069cf396 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,13 +1,16 @@ +import PIL.Image + +import torch from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms.functional import pil_to_tensor, to_pil_image erase_image_tensor = _FT.erase -# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels. 
-# Like the mixup and cutmix stuff - -# This function is copy-pasted to Image and OneHotLabel and may be refactored -# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: -# input = input.clone() -# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) +def erase_image_pil( + img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> PIL.Image.Image: + t_img = pil_to_tensor(img) + output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return to_pil_image(output, mode=img.mode) From d226e1603680633d06244fd59a217979c4ca5362 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 28 Jul 2022 12:45:01 +0200 Subject: [PATCH 13/13] Updates according to the review --- torchvision/prototype/transforms/_geometry.py | 62 ++++++++++--------- .../prototype/transforms/_transform.py | 17 +++-- torchvision/prototype/transforms/_utils.py | 29 ++------- .../transforms/functional/_geometry.py | 1 + torchvision/transforms/functional_pil.py | 2 +- 5 files changed, 53 insertions(+), 58 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index b1618b0eef5..decdee06073 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -233,6 +233,8 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") +# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) +# https://github.com/pytorch/vision/issues/6250 def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") @@ -437,18 +439,18 @@ def __init__( self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - if padding is not None: - _check_padding_arg(padding) - - if (padding is not None) or pad_if_needed: - _check_padding_mode_arg(padding_mode) - _check_fill_arg(fill) - self.padding = padding self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode + self._pad_op = None + if self.padding is not None: + self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode) + + if self.pad_if_needed: + self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode) + def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) _, height, width = get_image_dimensions(image) @@ -466,34 +468,36 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: left = torch.randint(0, width - output_width + 1, size=(1,)).item() return dict(top=top, left=left, height=output_height, width=output_width) - def _forward(self, flat_inputs: List[Any]) -> List[Any]: - if self.padding is not None: - flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs] - - image = query_image(flat_inputs) - _, height, width = get_image_dimensions(image) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.crop(inpt, **params) - # pad the width if needed - if self.pad_if_needed and width < self.size[1]: - padding = [self.size[1] - width, 0] - flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in 
flat_inputs] - # pad the height if needed - if self.pad_if_needed and height < self.size[0]: - padding = [0, self.size[0] - height] - flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs] + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] - params = self._get_params(flat_inputs) + if self._pad_op is not None: + sample = self._pad_op(sample) - return [F.crop(flat_input, **params) for flat_input in flat_inputs] + image = query_image(sample) + _, height, width = get_image_dimensions(image) - def forward(self, *inputs: Any) -> Any: - from torch.utils._pytree import tree_flatten, tree_unflatten + if self.pad_if_needed: + # This check is to explicitly ensure that self._pad_op is defined + if self._pad_op is None: + raise RuntimeError( + "Internal error, self._pad_op is None. " + "Please, fill an issue about that on https://github.com/pytorch/vision/issues" + ) - sample = inputs if len(inputs) > 1 else inputs[0] + # pad the width if needed + if width < self.size[1]: + self._pad_op.padding = [self.size[1] - width, 0] + sample = self._pad_op(sample) + # pad the height if needed + if height < self.size[0]: + self._pad_op.padding = [0, self.size[0] - height] + sample = self._pad_op(sample) - flat_inputs, spec = tree_flatten(sample) - out_flat_inputs = self._forward(flat_inputs) - return tree_unflatten(out_flat_inputs, spec) + return super().forward(sample) class RandomPerspective(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index d02732f552c..e7277748d3a 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,10 +1,11 @@ import enum -import functools from typing import Any, Dict +import PIL.Image import torch from torch import nn -from torchvision.prototype.utils._internal import apply_recursively +from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision.prototype.features import _Feature from torchvision.utils import _log_api_usage_once @@ -16,12 +17,20 @@ def __init__(self) -> None: def _get_params(self, sample: Any) -> Dict[str, Any]: return dict() - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: raise NotImplementedError def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) + + params = self._get_params(sample) + + flat_inputs, spec = tree_flatten(sample) + transformed_types = (torch.Tensor, _Feature, PIL.Image.Image) + flat_outputs = [ + self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs + ] + return tree_unflatten(flat_outputs, spec) def extra_repr(self) -> str: extra = [] diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index c9fe79e41fe..3de2f196c9f 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,10 +1,9 @@ -from typing import Any, Iterator, Optional, Tuple, Type, Union +from typing import Any, Tuple, Type, Union import PIL.Image import torch from torch.utils._pytree import tree_flatten from torchvision.prototype import features -from torchvision.prototype.utils._internal import query_recursively from .functional._meta import 
get_dimensions_image_pil, get_dimensions_image_tensor
@@ -18,22 +17,6 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
     raise TypeError("No image was found in the sample")
 
 
-# vfdev-5: let's use tree_flatten instead of query_recursively and internal fn to make the code simplier
-def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
-    def fn(
-        id: Tuple[Any, ...], input: Any
-    ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
-        if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)):
-            return id, input
-
-        return None
-
-    try:
-        return next(query_recursively(fn, sample))[1]
-    except StopIteration:
-        raise TypeError("No image was found in the sample")
-
-
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
@@ -47,16 +30,14 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     return channels, height, width
 
 
-def _extract_types(sample: Any) -> Iterator[Type]:
-    return query_recursively(lambda id, input: type(input), sample)
-
-
 def has_any(sample: Any, *types: Type) -> bool:
-    return any(issubclass(type, types) for type in _extract_types(sample))
+    flat_sample, _ = tree_flatten(sample)
+    return any(issubclass(type(obj), types) for obj in flat_sample)
 
 
 def has_all(sample: Any, *types: Type) -> bool:
-    return not bool(set(types) - set(_extract_types(sample)))
+    flat_sample, _ = tree_flatten(sample)
+    return not bool(set(types) - set([type(obj) for obj in flat_sample]))
 
 
 def is_simple_tensor(input: Any) -> bool:
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 61506393a4e..d5eec09bf2f 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -814,6 +814,7 @@ def elastic_bounding_box(
     format: features.BoundingBoxFormat,
     displacement: torch.Tensor,
 ) -> torch.Tensor:
+    # TODO: add in docstring about approximation we are doing for grid inversion
     displacement = displacement.to(bounding_box.device)
     original_shape = bounding_box.shape
 
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 768176e6783..ec65b62314c 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -260,7 +260,7 @@ def _parse_fill(
 ) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
 
     # Process fill color for affine transforms
-    num_bands = len(img.getbands())
+    num_bands = get_image_num_channels(img)
    if fill is None:
         fill = 0
     if isinstance(fill, (int, float)) and num_bands > 1:
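
Note on the pytree-based helpers in the last patch: has_any, has_all, and the reworked Transform.forward all rely on torch.utils._pytree.tree_flatten / tree_unflatten to walk arbitrarily nested samples. The stand-alone sketch below is illustrative only and is not part of the patches; the sample dict and the bump_tensors helper are invented for the example, but the flatten/unflatten calls are the same ones the patch uses.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def has_any(sample, *types):
    # Flatten nested dicts/lists/tuples into leaves and check the leaf types.
    flat_sample, _ = tree_flatten(sample)
    return any(isinstance(obj, types) for obj in flat_sample)


def bump_tensors(sample):
    # Transform tensor leaves only, then restore the original container structure.
    flat, spec = tree_flatten(sample)
    flat = [leaf + 1 if isinstance(leaf, torch.Tensor) else leaf for leaf in flat]
    return tree_unflatten(flat, spec)


sample = {"image": torch.zeros(3, 4, 4), "path": "/path/to/somewhere", "boxes": [torch.ones(4)]}
print(has_any(sample, torch.Tensor))                 # True
print(has_any(sample, float))                        # False
print(bump_tensors(sample)["image"].mean().item())   # 1.0

This behaviour is why the patch can drop the recursive query_recursively / apply_recursively helpers: flattening once and dispatching on leaf types covers the same nested-sample cases with less machinery.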