diff --git a/test/datasets_utils.py b/test/datasets_utils.py index c02ffeb0d68..e8290b55c4b 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -584,11 +584,8 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs - # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release - # branch, we screwed up the roll-out - from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2 - from torchvision.prototype.datapoints._datapoint import Datapoint + from torchvision.datapoints import wrap_dataset_for_transforms_v2 + from torchvision.datapoints._datapoint import Datapoint try: with self.create_dataset(config) as (dataset, _): @@ -596,12 +593,13 @@ def test_transforms_v2_wrapper(self, config): wrapped_sample = wrapped_dataset[0] assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) except TypeError as error: - if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"): - return + msg = f"No wrapper exists for dataset class {type(dataset).__name__}" + if str(error).startswith(msg): + pytest.skip(msg) raise error except RuntimeError as error: if "currently not supported by this wrapper" in str(error): - return + pytest.skip("Config is currently not supported by this wrapper") raise error diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index f2ae8d2b9e5..8648a09ad94 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -12,12 +12,13 @@ import pytest import torch import torch.testing +import torchvision.prototype.datapoints as proto_datapoints from datasets_utils import combinations_grid from torch.nn.functional import one_hot from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair -from torchvision.prototype import datapoints -from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor +from torchvision import datapoints from torchvision.transforms.functional_tensor import _max_value as get_max_value +from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor __all__ = [ "assert_close", @@ -457,7 +458,7 @@ def fn(shape, dtype, device): # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, # regardless of the requested dtype, e.g. 
0 or 0.0 rather than 0 or 0.123 data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) - return datapoints.Label(data, categories=categories) + return proto_datapoints.Label(data, categories=categories) return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) @@ -481,7 +482,7 @@ def fn(shape, dtype, device): # since `one_hot` only supports int64 label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) data = one_hot(label, num_classes=num_categories).to(dtype) - return datapoints.OneHotLabel(data, categories=categories) + return proto_datapoints.OneHotLabel(data, categories=categories) return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 442dd526ed3..308f787ba6b 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -1,10 +1,10 @@ import collections.abc import pytest -import torchvision.prototype.transforms.functional as F +import torchvision.transforms.v2.functional as F from prototype_common_utils import InfoBase, TestMark from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition -from torchvision.prototype import datapoints +from torchvision import datapoints __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ce05c980a87..a0f7da5e262 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -8,7 +8,7 @@ import pytest import torch.testing import torchvision.ops -import torchvision.prototype.transforms.functional as F +import torchvision.transforms.v2.functional as F from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, @@ -28,7 +28,7 @@ TestMark, ) from torch.utils._pytree import tree_map -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding __all__ = ["KernelInfo", "KERNEL_INFOS"] @@ -2383,19 +2383,18 @@ def sample_inputs_convert_dtype_video(): def sample_inputs_uniform_temporal_subsample_video(): for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]): - for temporal_dim in [-4, len(video_loader.shape) - 4]: - yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim) + yield ArgsKwargs(video_loader, num_samples=2) -def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): +def reference_uniform_temporal_subsample_video(x, num_samples): # Copy-pasted from # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19 - t = x.shape[temporal_dim] + t = x.shape[-4] assert num_samples > 0 and t > 0 # Sample by nearest neighbor interpolation if num_samples > t. 
indices = torch.linspace(0, t - 1, num_samples) indices = torch.clamp(indices, 0, t - 1).long() - return torch.index_select(x, temporal_dim, indices) + return torch.index_select(x, -4, indices) def reference_inputs_uniform_temporal_subsample_video(): @@ -2410,12 +2409,5 @@ def reference_inputs_uniform_temporal_subsample_video(): sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video, reference_fn=reference_uniform_temporal_subsample_video, reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video, - test_marks=[ - TestMark( - ("TestKernels", "test_batched_vs_single"), - pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"), - condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0, - ), - ], ) ) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index c2cc0986b71..615fa9f614d 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -5,8 +5,8 @@ from PIL import Image -from torchvision import datasets -from torchvision.prototype import datapoints +from torchvision import datapoints, datasets +from torchvision.prototype import datapoints as proto_datapoints @pytest.mark.parametrize( @@ -24,38 +24,38 @@ ], ) def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): - datapoint = datapoints.Label(data, requires_grad=input_requires_grad) + datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad) assert datapoint.requires_grad is expected_requires_grad def test_isinstance(): assert isinstance( - datapoints.Label([0, 1, 0], categories=["foo", "bar"]), + proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]), torch.Tensor, ) def test_wrapping_no_copy(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) assert label.data_ptr() == tensor.data_ptr() def test_to_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) label_to = label.to(torch.int32) - assert type(label_to) is datapoints.Label + assert type(label_to) is proto_datapoints.Label assert label_to.dtype is torch.int32 assert label_to.categories is label.categories def test_to_datapoint_reference(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32) tensor_to = tensor.to(label) @@ -65,31 +65,31 @@ def test_to_datapoint_reference(): def test_clone_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) label_clone = label.clone() - assert type(label_clone) is datapoints.Label + assert type(label_clone) is proto_datapoints.Label assert label_clone.data_ptr() != label.data_ptr() assert label_clone.categories is label.categories def test_requires_grad__wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.float32) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) assert not label.requires_grad label_requires_grad = label.requires_grad_(True) - assert type(label_requires_grad) is datapoints.Label + assert 
type(label_requires_grad) is proto_datapoints.Label assert label.requires_grad assert label_requires_grad.requires_grad def test_other_op_no_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 @@ -107,33 +107,33 @@ def test_other_op_no_wrapping(): ) def test_no_tensor_output_op_no_wrapping(op): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) output = op(label) - assert type(output) is not datapoints.Label + assert type(output) is not proto_datapoints.Label def test_inplace_op_no_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) output = label.add_(0) assert type(output) is torch.Tensor - assert type(label) is datapoints.Label + assert type(label) is proto_datapoints.Label def test_wrap_like(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = datapoints.Label(tensor, categories=["foo", "bar"]) + label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 - label_new = datapoints.Label.wrap_like(label, output) + label_new = proto_datapoints.Label.wrap_like(label, output) - assert type(label_new) is datapoints.Label + assert type(label_new) is proto_datapoints.Label assert label_new.data_ptr() == output.data_ptr() assert label_new.categories is label.categories diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py index 7b33dc3e8a0..4848e799f04 100644 --- a/test/test_prototype_datasets_builtin.py +++ b/test/test_prototype_datasets_builtin.py @@ -5,8 +5,8 @@ import pytest import torch +import torchvision.transforms.v2 as transforms -import torchvision.prototype.transforms.utils from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair @@ -19,10 +19,13 @@ from torchdata.dataloader2.graph.utils import traverse_dps from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchdata.datapipes.utils import StreamWrapper +from torchvision import datapoints from torchvision._utils import sequence_to_str -from torchvision.prototype import datapoints, datasets, transforms +from torchvision.prototype import datasets +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import EncodedImage from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE +from torchvision.transforms.v2.utils import is_simple_tensor def assert_samples_equal(*args, msg=None, **kwargs): @@ -141,9 +144,7 @@ def test_no_unaccompanied_simple_tensors(self, dataset_mock, config): dataset, _ = dataset_mock.load(config) sample = next_consume(iter(dataset)) - simple_tensors = { - key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value) - } + simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)} if simple_tensors and not any( isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values() @@ -276,6 +277,6 @@ def test_sample_content(self, 
dataset_mock, config): assert "label" in sample assert isinstance(sample["image"], datapoints.Image) - assert isinstance(sample["label"], datapoints.Label) + assert isinstance(sample["label"], Label) assert sample["image"].shape == (1, 16, 16) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index ff772c5151f..b8f20a26b24 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -10,8 +10,11 @@ import PIL.Image import pytest import torch +import torchvision.prototype.datapoints as proto_datapoints +import torchvision.prototype.transforms as proto_transforms +import torchvision.transforms.v2 as transforms -import torchvision.prototype.transforms.utils +import torchvision.transforms.v2.utils from common_utils import cpu_and_gpu from prototype_common_utils import ( assert_equal, @@ -28,11 +31,12 @@ make_videos, ) from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision import datapoints from torchvision.ops.boxes import box_iou -from torchvision.prototype import datapoints, transforms -from torchvision.prototype.transforms import functional as F -from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image +from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2._utils import _convert_fill_arg +from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -281,8 +285,8 @@ def test_common(self, transform, adapter, container_type, image_or_video, device ], ) for transform in [ - transforms.RandomMixup(alpha=1.0), - transforms.RandomCutmix(alpha=1.0), + proto_transforms.RandomMixup(alpha=1.0), + proto_transforms.RandomCutmix(alpha=1.0), ] ] ) @@ -563,7 +567,7 @@ def test_assertions(self): def test__transform(self, padding, fill, padding_mode, mocker): transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + fn = mocker.patch("torchvision.transforms.v2.functional.pad") inpt = mocker.MagicMock(spec=datapoints.Image) _ = transform(inpt) @@ -576,7 +580,7 @@ def test__transform(self, padding, fill, padding_mode, mocker): def test__transform_image_mask(self, fill, mocker): transform = transforms.Pad(1, fill=fill, padding_mode="constant") - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + fn = mocker.patch("torchvision.transforms.v2.functional.pad") image = datapoints.Image(torch.rand(3, 32, 32)) mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) inpt = [image, mask] @@ -634,7 +638,7 @@ def test__transform(self, fill, side_range, mocker): transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + fn = mocker.patch("torchvision.transforms.v2.functional.pad") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users # Otherwise, we can mock transform._get_params @@ -651,7 +655,7 @@ def test__transform(self, fill, side_range, mocker): def test__transform_image_mask(self, fill, mocker): transform = transforms.RandomZoomOut(fill=fill, p=1.0) - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + fn = mocker.patch("torchvision.transforms.v2.functional.pad") image = 
datapoints.Image(torch.rand(3, 32, 32)) mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) inpt = [image, mask] @@ -724,7 +728,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): else: assert transform.degrees == [float(-degrees), float(degrees)] - fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") + fn = mocker.patch("torchvision.transforms.v2.functional.rotate") inpt = mocker.MagicMock(spec=datapoints.Image) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users @@ -859,7 +863,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker else: assert transform.degrees == [float(-degrees), float(degrees)] - fn = mocker.patch("torchvision.prototype.transforms.functional.affine") + fn = mocker.patch("torchvision.transforms.v2.functional.affine") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -964,8 +968,8 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): ) else: expected.spatial_size = inpt.spatial_size - _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected) - fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop") + _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected) + fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users @@ -1036,7 +1040,7 @@ def test__transform(self, kernel_size, sigma, mocker): else: assert transform.sigma == [sigma, sigma] - fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") + fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -1068,7 +1072,7 @@ class TestRandomColorOp: def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): transform = transform_cls(p=p, **kwargs) - fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") + fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}") inpt = mocker.MagicMock(spec=datapoints.Image) _ = transform(inpt) if p > 0.0: @@ -1104,7 +1108,7 @@ def test__transform(self, distortion_scale, mocker): fill = 12 transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) - fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") + fn = mocker.patch("torchvision.transforms.v2.functional.perspective") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -1178,7 +1182,7 @@ def test__transform(self, alpha, sigma, mocker): else: assert transform.sigma == sigma - fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") + fn = mocker.patch("torchvision.transforms.v2.functional.elastic") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -1251,13 +1255,13 @@ def test__transform(self, mocker, p): w_sentinel = mocker.MagicMock() v_sentinel = mocker.MagicMock() mocker.patch( - "torchvision.prototype.transforms._augment.RandomErasing._get_params", + "torchvision.transforms.v2._augment.RandomErasing._get_params", return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel), ) inpt_sentinel = mocker.MagicMock() - mock = 
mocker.patch("torchvision.prototype.transforms._augment.F.erase") + mock = mocker.patch("torchvision.transforms.v2._augment.F.erase") output = transform(inpt_sentinel) if p: @@ -1300,7 +1304,7 @@ class TestToImageTensor: ) def test__transform(self, inpt_type, mocker): fn = mocker.patch( - "torchvision.prototype.transforms.functional.to_image_tensor", + "torchvision.transforms.v2.functional.to_image_tensor", return_value=torch.rand(1, 3, 8, 8), ) @@ -1319,7 +1323,7 @@ class TestToImagePIL: [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") + fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToImagePIL() @@ -1336,7 +1340,7 @@ class TestToPILImage: [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") + fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToPILImage() @@ -1443,7 +1447,7 @@ def test__transform_empty_params(self, mocker): transform = transforms.RandomIoUCrop(sampler_options=[2.0]) image = datapoints.Image(torch.rand(1, 3, 4, 4)) bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) - label = datapoints.Label(torch.tensor([1])) + label = proto_datapoints.Label(torch.tensor([1])) sample = [image, bboxes, label] # Let's mock transform._get_params to control the output: transform._get_params = mocker.MagicMock(return_value={}) @@ -1454,7 +1458,7 @@ def test_forward_assertion(self): transform = transforms.RandomIoUCrop() with pytest.raises( TypeError, - match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels", + match="requires input sample to contain tensor or PIL images and bounding boxes", ): transform(torch.tensor(0)) @@ -1463,13 +1467,11 @@ def test__transform(self, mocker): image = datapoints.Image(torch.rand(3, 32, 24)) bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,)) - label = datapoints.Label(torch.randint(0, 10, size=(6,))) - ohe_label = datapoints.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) masks = make_detection_mask((32, 24), num_objects=6) - sample = [image, bboxes, label, ohe_label, masks] + sample = [image, bboxes, masks] - fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x) + fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x) is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) @@ -1493,17 +1495,7 @@ def test__transform(self, mocker): assert isinstance(output_bboxes, datapoints.BoundingBox) assert len(output_bboxes) == expected_within_targets - # check labels - output_label = output[2] - assert isinstance(output_label, datapoints.Label) - assert len(output_label) == expected_within_targets - torch.testing.assert_close(output_label, label[is_within_crop_area]) - - output_ohe_label = output[3] - assert isinstance(output_ohe_label, datapoints.OneHotLabel) - torch.testing.assert_close(output_ohe_label, 
ohe_label[is_within_crop_area]) - - output_masks = output[4] + output_masks = output[2] assert isinstance(output_masks, datapoints.Mask) assert len(output_masks) == expected_within_targets @@ -1545,12 +1537,12 @@ def test__transform(self, mocker): size_sentinel = mocker.MagicMock() mocker.patch( - "torchvision.prototype.transforms._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel) + "torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel) ) inpt_sentinel = mocker.MagicMock() - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") + mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") transform(inpt_sentinel) mock.assert_called_once_with( @@ -1592,13 +1584,13 @@ def test__transform(self, mocker): size_sentinel = mocker.MagicMock() mocker.patch( - "torchvision.prototype.transforms._geometry.RandomShortestSize._get_params", + "torchvision.transforms.v2._geometry.RandomShortestSize._get_params", return_value=dict(size=size_sentinel), ) inpt_sentinel = mocker.MagicMock() - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") + mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") transform(inpt_sentinel) mock.assert_called_once_with( @@ -1613,13 +1605,13 @@ def create_fake_image(self, mocker, image_type): return mocker.MagicMock(spec=image_type) def test__extract_image_targets_assertion(self, mocker): - transform = transforms.SimpleCopyPaste() + transform = proto_transforms.SimpleCopyPaste() flat_sample = [ # images, batch size = 2 self.create_fake_image(mocker, datapoints.Image), # labels, bboxes, masks - mocker.MagicMock(spec=datapoints.Label), + mocker.MagicMock(spec=proto_datapoints.Label), mocker.MagicMock(spec=datapoints.BoundingBox), mocker.MagicMock(spec=datapoints.Mask), # labels, bboxes, masks @@ -1631,9 +1623,9 @@ def test__extract_image_targets_assertion(self, mocker): transform._extract_image_targets(flat_sample) @pytest.mark.parametrize("image_type", [datapoints.Image, PIL.Image.Image, torch.Tensor]) - @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) + @pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel]) def test__extract_image_targets(self, image_type, label_type, mocker): - transform = transforms.SimpleCopyPaste() + transform = proto_transforms.SimpleCopyPaste() flat_sample = [ # images, batch size = 2 @@ -1669,7 +1661,7 @@ def test__extract_image_targets(self, image_type, label_type, mocker): assert isinstance(target[key], type_) assert target[key] in flat_sample - @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) + @pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel]) def test__copy_paste(self, label_type): image = 2 * torch.ones(3, 32, 32) masks = torch.zeros(2, 32, 32) @@ -1679,7 +1671,7 @@ def test__copy_paste(self, label_type): blending = True resize_interpolation = InterpolationMode.BILINEAR antialias = None - if label_type == datapoints.OneHotLabel: + if label_type == proto_datapoints.OneHotLabel: labels = torch.nn.functional.one_hot(labels, num_classes=5) target = { "boxes": datapoints.BoundingBox( @@ -1694,7 +1686,7 @@ def test__copy_paste(self, label_type): paste_masks[0, 13:19, 12:18] = 1 paste_masks[1, 15:19, 1:8] = 1 paste_labels = torch.tensor([3, 4]) - if label_type == datapoints.OneHotLabel: + if label_type == proto_datapoints.OneHotLabel: paste_labels = 
torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_target = { "boxes": datapoints.BoundingBox( @@ -1704,7 +1696,7 @@ def test__copy_paste(self, label_type): "labels": label_type(paste_labels), } - transform = transforms.SimpleCopyPaste() + transform = proto_transforms.SimpleCopyPaste() random_selection = torch.tensor([0, 1]) output_image, output_target = transform._copy_paste( image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias @@ -1716,7 +1708,7 @@ def test__copy_paste(self, label_type): torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) expected_labels = torch.tensor([1, 2, 3, 4]) - if label_type == datapoints.OneHotLabel: + if label_type == proto_datapoints.OneHotLabel: expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5) torch.testing.assert_close(output_target["labels"], label_type(expected_labels)) @@ -1731,7 +1723,7 @@ def test__get_params(self, mocker): batch_shape = (10,) spatial_size = (11, 5) - transform = transforms.FixedSizeCrop(size=crop_size) + transform = proto_transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ make_image(size=spatial_size, color_space="RGB"), @@ -1759,9 +1751,8 @@ def test__transform(self, mocker, needs): fill_sentinel = 12 padding_mode_sentinel = mocker.MagicMock() - transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) + transform = proto_transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) transform._transformed_types = (mocker.MagicMock,) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) needs_crop, needs_pad = needs @@ -1810,7 +1801,7 @@ def test__transform(self, mocker, needs): if not needs_crop: assert args[0] is inpt_sentinel assert args[1] is padding_sentinel - fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel) + fill_sentinel = _convert_fill_arg(fill_sentinel) assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) else: mock_pad.assert_not_called() @@ -1839,8 +1830,7 @@ def test__transform_culling(self, mocker): masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,)) labels = make_label(extra_dims=(batch_size,)) - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) + transform = proto_transforms.FixedSizeCrop((-1, -1)) mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) output = transform( @@ -1877,8 +1867,7 @@ def test__transform_bounding_box_clamping(self, mocker): ) mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) + transform = proto_transforms.FixedSizeCrop((-1, -1)) mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) transform(bounding_box) @@ -1922,10 +1911,10 @@ def test__transform(self, inpt): class TestLabelToOneHot: def test__transform(self): categories = ["apple", "pear", "pineapple"] - labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories) - transform = transforms.LabelToOneHot() + labels = proto_datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories) + transform = proto_transforms.LabelToOneHot() 
ohe_labels = transform(labels) - assert isinstance(ohe_labels, datapoints.OneHotLabel) + assert isinstance(ohe_labels, proto_datapoints.OneHotLabel) assert ohe_labels.shape == (4, 3) assert ohe_labels.categories == labels.categories == categories @@ -1956,13 +1945,13 @@ def test__transform(self, mocker): size_sentinel = mocker.MagicMock() mocker.patch( - "torchvision.prototype.transforms._geometry.RandomResize._get_params", + "torchvision.transforms.v2._geometry.RandomResize._get_params", return_value=dict(size=size_sentinel), ) inpt_sentinel = mocker.MagicMock() - mock_resize = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") + mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize") transform(inpt_sentinel) mock_resize.assert_called_with( @@ -2048,7 +2037,7 @@ def test_call(self, dims, inverse_dims): int=0, ) - transform = transforms.PermuteDimensions(dims) + transform = proto_transforms.PermuteDimensions(dims) transformed_sample = transform(sample) for key, value in sample.items(): @@ -2056,7 +2045,7 @@ def test_call(self, dims, inverse_dims): transformed_value = transformed_sample[key] if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video) ): if transform.dims.get(value_type) is not None: assert transformed_value.permute(inverse_dims[value_type]).equal(value) @@ -2067,14 +2056,14 @@ def test_call(self, dims, inverse_dims): @pytest.mark.filterwarnings("error") def test_plain_tensor_call(self): tensor = torch.empty((2, 3, 4)) - transform = transforms.PermuteDimensions(dims=(1, 2, 0)) + transform = proto_transforms.PermuteDimensions(dims=(1, 2, 0)) assert transform(tensor).shape == (3, 4, 2) @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) def test_plain_tensor_warning(self, other_type): with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + proto_transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) class TestTransposeDimensions: @@ -2094,7 +2083,7 @@ def test_call(self, dims): int=0, ) - transform = transforms.TransposeDimensions(dims) + transform = proto_transforms.TransposeDimensions(dims) transformed_sample = transform(sample) for key, value in sample.items(): @@ -2103,7 +2092,7 @@ def test_call(self, dims): transposed_dims = transform.dims.get(value_type) if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video) ): if transposed_dims is not None: assert transformed_value.transpose(*transposed_dims).equal(value) @@ -2114,14 +2103,14 @@ def test_call(self, dims): @pytest.mark.filterwarnings("error") def test_plain_tensor_call(self): tensor = torch.empty((2, 3, 4)) - transform = transforms.TransposeDimensions(dims=(0, 2)) + transform = proto_transforms.TransposeDimensions(dims=(0, 2)) assert transform(tensor).shape == (4, 3, 2) @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) def test_plain_tensor_warning(self, other_type): with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + 
proto_transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) class TestUniformTemporalSubsample: diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index c6709a5e550..9b3482f3f0a 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -12,6 +12,8 @@ import pytest import torch +import torchvision.prototype.transforms as prototype_transforms +import torchvision.transforms.v2 as v2_transforms from prototype_common_utils import ( ArgsKwargs, assert_close, @@ -24,13 +26,13 @@ make_segmentation_mask, ) from torch import nn -from torchvision import transforms as legacy_transforms +from torchvision import datapoints, transforms as legacy_transforms from torchvision._utils import sequence_to_str -from torchvision.prototype import datapoints, transforms as prototype_transforms -from torchvision.prototype.transforms import functional as prototype_F -from torchvision.prototype.transforms.functional import to_image_pil -from torchvision.prototype.transforms.utils import query_spatial_size + from torchvision.transforms import functional as legacy_F +from torchvision.transforms.v2 import functional as prototype_F +from torchvision.transforms.v2.functional import to_image_pil +from torchvision.transforms.v2.utils import query_spatial_size DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) @@ -71,7 +73,7 @@ def __init__( CONSISTENCY_CONFIGS = [ ConsistencyConfig( - prototype_transforms.Normalize, + v2_transforms.Normalize, legacy_transforms.Normalize, [ ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), @@ -80,14 +82,14 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]), ), ConsistencyConfig( - prototype_transforms.Resize, + v2_transforms.Resize, legacy_transforms.Resize, [ NotScriptableArgsKwargs(32), ArgsKwargs([32]), ArgsKwargs((32, 29)), - ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), + ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST), + ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC), ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST), ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR), ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC), @@ -100,7 +102,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.CenterCrop, + v2_transforms.CenterCrop, legacy_transforms.CenterCrop, [ ArgsKwargs(18), @@ -108,7 +110,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.FiveCrop, + v2_transforms.FiveCrop, legacy_transforms.FiveCrop, [ ArgsKwargs(18), @@ -117,7 +119,7 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]), ), ConsistencyConfig( - prototype_transforms.TenCrop, + v2_transforms.TenCrop, legacy_transforms.TenCrop, [ ArgsKwargs(18), @@ -127,7 +129,7 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]), ), ConsistencyConfig( - prototype_transforms.Pad, + v2_transforms.Pad, legacy_transforms.Pad, [ NotScriptableArgsKwargs(3), @@ -143,7 +145,7 @@ def __init__( ), *[ ConsistencyConfig( - prototype_transforms.LinearTransformation, + v2_transforms.LinearTransformation, legacy_transforms.LinearTransformation, [ ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)), @@ 
-164,7 +166,7 @@ def __init__( ] ], ConsistencyConfig( - prototype_transforms.Grayscale, + v2_transforms.Grayscale, legacy_transforms.Grayscale, [ ArgsKwargs(num_output_channels=1), @@ -175,7 +177,7 @@ def __init__( closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( - prototype_transforms.ConvertDtype, + v2_transforms.ConvertDtype, legacy_transforms.ConvertImageDtype, [ ArgsKwargs(torch.float16), @@ -189,7 +191,7 @@ def __init__( closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( - prototype_transforms.ToPILImage, + v2_transforms.ToPILImage, legacy_transforms.ToPILImage, [NotScriptableArgsKwargs()], make_images_kwargs=dict( @@ -204,7 +206,7 @@ def __init__( supports_pil=False, ), ConsistencyConfig( - prototype_transforms.Lambda, + v2_transforms.Lambda, legacy_transforms.Lambda, [ NotScriptableArgsKwargs(lambda image: image / 2), @@ -214,7 +216,7 @@ def __init__( supports_pil=False, ), ConsistencyConfig( - prototype_transforms.RandomHorizontalFlip, + v2_transforms.RandomHorizontalFlip, legacy_transforms.RandomHorizontalFlip, [ ArgsKwargs(p=0), @@ -222,7 +224,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.RandomVerticalFlip, + v2_transforms.RandomVerticalFlip, legacy_transforms.RandomVerticalFlip, [ ArgsKwargs(p=0), @@ -230,7 +232,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.RandomEqualize, + v2_transforms.RandomEqualize, legacy_transforms.RandomEqualize, [ ArgsKwargs(p=0), @@ -239,7 +241,7 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), ), ConsistencyConfig( - prototype_transforms.RandomInvert, + v2_transforms.RandomInvert, legacy_transforms.RandomInvert, [ ArgsKwargs(p=0), @@ -247,7 +249,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.RandomPosterize, + v2_transforms.RandomPosterize, legacy_transforms.RandomPosterize, [ ArgsKwargs(p=0, bits=5), @@ -257,7 +259,7 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), ), ConsistencyConfig( - prototype_transforms.RandomSolarize, + v2_transforms.RandomSolarize, legacy_transforms.RandomSolarize, [ ArgsKwargs(p=0, threshold=0.5), @@ -267,7 +269,7 @@ def __init__( ), *[ ConsistencyConfig( - prototype_transforms.RandomAutocontrast, + v2_transforms.RandomAutocontrast, legacy_transforms.RandomAutocontrast, [ ArgsKwargs(p=0), @@ -279,7 +281,7 @@ def __init__( for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))] ], ConsistencyConfig( - prototype_transforms.RandomAdjustSharpness, + v2_transforms.RandomAdjustSharpness, legacy_transforms.RandomAdjustSharpness, [ ArgsKwargs(p=0, sharpness_factor=0.5), @@ -289,7 +291,7 @@ def __init__( closeness_kwargs={"atol": 1e-6, "rtol": 1e-6}, ), ConsistencyConfig( - prototype_transforms.RandomGrayscale, + v2_transforms.RandomGrayscale, legacy_transforms.RandomGrayscale, [ ArgsKwargs(p=0), @@ -300,14 +302,14 @@ def __init__( closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( - prototype_transforms.RandomResizedCrop, + v2_transforms.RandomResizedCrop, legacy_transforms.RandomResizedCrop, [ ArgsKwargs(16), ArgsKwargs(17, scale=(0.3, 0.7)), ArgsKwargs(25, ratio=(0.5, 1.5)), - ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), + ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST), + ArgsKwargs((33, 26), 
interpolation=v2_transforms.InterpolationMode.BICUBIC), ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST), ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC), ArgsKwargs((29, 32), antialias=False), @@ -315,7 +317,7 @@ def __init__( ], ), ConsistencyConfig( - prototype_transforms.RandomErasing, + v2_transforms.RandomErasing, legacy_transforms.RandomErasing, [ ArgsKwargs(p=0), @@ -329,7 +331,7 @@ def __init__( supports_pil=False, ), ConsistencyConfig( - prototype_transforms.ColorJitter, + v2_transforms.ColorJitter, legacy_transforms.ColorJitter, [ ArgsKwargs(), @@ -347,7 +349,7 @@ def __init__( ), *[ ConsistencyConfig( - prototype_transforms.ElasticTransform, + v2_transforms.ElasticTransform, legacy_transforms.ElasticTransform, [ ArgsKwargs(), @@ -355,8 +357,8 @@ def __init__( ArgsKwargs(alpha=(15.3, 27.2)), ArgsKwargs(sigma=3.0), ArgsKwargs(sigma=(2.5, 3.9)), - ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC), + ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST), + ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC), ArgsKwargs(interpolation=PIL.Image.NEAREST), ArgsKwargs(interpolation=PIL.Image.BICUBIC), ArgsKwargs(fill=1), @@ -370,7 +372,7 @@ def __init__( for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})] ], ConsistencyConfig( - prototype_transforms.GaussianBlur, + v2_transforms.GaussianBlur, legacy_transforms.GaussianBlur, [ ArgsKwargs(kernel_size=3), @@ -381,7 +383,7 @@ def __init__( closeness_kwargs={"rtol": 1e-5, "atol": 1e-5}, ), ConsistencyConfig( - prototype_transforms.RandomAffine, + v2_transforms.RandomAffine, legacy_transforms.RandomAffine, [ ArgsKwargs(degrees=30.0), @@ -392,7 +394,7 @@ def __init__( ArgsKwargs(degrees=0.0, shear=(8, 17)), ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), - ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST), ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST), ArgsKwargs(degrees=30.0, fill=1), ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), @@ -401,7 +403,7 @@ def __init__( removed_params=["fillcolor", "resample"], ), ConsistencyConfig( - prototype_transforms.RandomCrop, + v2_transforms.RandomCrop, legacy_transforms.RandomCrop, [ ArgsKwargs(12), @@ -421,13 +423,13 @@ def __init__( make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]), ), ConsistencyConfig( - prototype_transforms.RandomPerspective, + v2_transforms.RandomPerspective, legacy_transforms.RandomPerspective, [ ArgsKwargs(p=0), ArgsKwargs(p=1), ArgsKwargs(p=1, distortion_scale=0.3), - ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST), ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST), ArgsKwargs(p=1, distortion_scale=0.1, fill=1), ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)), @@ -435,12 +437,12 @@ def __init__( closeness_kwargs={"atol": None, "rtol": None}, ), ConsistencyConfig( - prototype_transforms.RandomRotation, + v2_transforms.RandomRotation, legacy_transforms.RandomRotation, [ ArgsKwargs(degrees=30.0), ArgsKwargs(degrees=(-20.0, 10.0)), - ArgsKwargs(degrees=30.0, 
interpolation=prototype_transforms.InterpolationMode.BILINEAR), + ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR), ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR), ArgsKwargs(degrees=30.0, expand=True), ArgsKwargs(degrees=30.0, center=(0, 0)), @@ -450,43 +452,43 @@ def __init__( removed_params=["resample"], ), ConsistencyConfig( - prototype_transforms.PILToTensor, + v2_transforms.PILToTensor, legacy_transforms.PILToTensor, ), ConsistencyConfig( - prototype_transforms.ToTensor, + v2_transforms.ToTensor, legacy_transforms.ToTensor, ), ConsistencyConfig( - prototype_transforms.Compose, + v2_transforms.Compose, legacy_transforms.Compose, ), ConsistencyConfig( - prototype_transforms.RandomApply, + v2_transforms.RandomApply, legacy_transforms.RandomApply, ), ConsistencyConfig( - prototype_transforms.RandomChoice, + v2_transforms.RandomChoice, legacy_transforms.RandomChoice, ), ConsistencyConfig( - prototype_transforms.RandomOrder, + v2_transforms.RandomOrder, legacy_transforms.RandomOrder, ), ConsistencyConfig( - prototype_transforms.AugMix, + v2_transforms.AugMix, legacy_transforms.AugMix, ), ConsistencyConfig( - prototype_transforms.AutoAugment, + v2_transforms.AutoAugment, legacy_transforms.AutoAugment, ), ConsistencyConfig( - prototype_transforms.RandAugment, + v2_transforms.RandAugment, legacy_transforms.RandAugment, ), ConsistencyConfig( - prototype_transforms.TrivialAugmentWide, + v2_transforms.TrivialAugmentWide, legacy_transforms.TrivialAugmentWide, ), ] @@ -680,19 +682,19 @@ def test_call_consistency(config, args_kwargs): id=transform_cls.__name__, ) for transform_cls, get_params_args_kwargs in [ - (prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), - (prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), - (prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), - (prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), - (prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), + (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), + (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), + (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), + (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), + (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), ( - prototype_transforms.RandomAffine, + v2_transforms.RandomAffine, ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), ), - (prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), - (prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), - (prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), - (prototype_transforms.AutoAugment, ArgsKwargs(5)), + (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), + (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), + (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), + (v2_transforms.AutoAugment, ArgsKwargs(5)), ] ], ) @@ -767,10 +769,10 @@ class TestContainerTransforms: """ def test_compose(self): - prototype_transform = prototype_transforms.Compose( + prototype_transform = v2_transforms.Compose( [ - 
prototype_transforms.Resize(256), - prototype_transforms.CenterCrop(224), + v2_transforms.Resize(256), + v2_transforms.CenterCrop(224), ] ) legacy_transform = legacy_transforms.Compose( @@ -785,11 +787,11 @@ def test_compose(self): @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1]) @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList]) def test_random_apply(self, p, sequence_type): - prototype_transform = prototype_transforms.RandomApply( + prototype_transform = v2_transforms.RandomApply( sequence_type( [ - prototype_transforms.Resize(256), - prototype_transforms.CenterCrop(224), + v2_transforms.Resize(256), + v2_transforms.CenterCrop(224), ] ), p=p, @@ -814,9 +816,9 @@ def test_random_apply(self, p, sequence_type): # We can't test other values for `p` since the random parameter generation is different @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)]) def test_random_choice(self, probabilities): - prototype_transform = prototype_transforms.RandomChoice( + prototype_transform = v2_transforms.RandomChoice( [ - prototype_transforms.Resize(256), + v2_transforms.Resize(256), legacy_transforms.CenterCrop(224), ], probabilities=probabilities, @@ -834,7 +836,7 @@ def test_random_choice(self, probabilities): class TestToTensorTransforms: def test_pil_to_tensor(self): - prototype_transform = prototype_transforms.PILToTensor() + prototype_transform = v2_transforms.PILToTensor() legacy_transform = legacy_transforms.PILToTensor() for image in make_images(extra_dims=[()]): @@ -844,7 +846,7 @@ def test_pil_to_tensor(self): def test_to_tensor(self): with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")): - prototype_transform = prototype_transforms.ToTensor() + prototype_transform = v2_transforms.ToTensor() legacy_transform = legacy_transforms.ToTensor() for image in make_images(extra_dims=[()]): @@ -867,14 +869,14 @@ class TestAATransforms: @pytest.mark.parametrize( "interpolation", [ - prototype_transforms.InterpolationMode.NEAREST, - prototype_transforms.InterpolationMode.BILINEAR, + v2_transforms.InterpolationMode.NEAREST, + v2_transforms.InterpolationMode.BILINEAR, PIL.Image.NEAREST, ], ) def test_randaug(self, inpt, interpolation, mocker): t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) - t = prototype_transforms.RandAugment(interpolation=interpolation, num_ops=1) + t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1) le = len(t._AUGMENTATION_SPACE) keys = list(t._AUGMENTATION_SPACE.keys()) @@ -909,14 +911,14 @@ def test_randaug(self, inpt, interpolation, mocker): @pytest.mark.parametrize( "interpolation", [ - prototype_transforms.InterpolationMode.NEAREST, - prototype_transforms.InterpolationMode.BILINEAR, + v2_transforms.InterpolationMode.NEAREST, + v2_transforms.InterpolationMode.BILINEAR, PIL.Image.NEAREST, ], ) def test_trivial_aug(self, inpt, interpolation, mocker): t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) - t = prototype_transforms.TrivialAugmentWide(interpolation=interpolation) + t = v2_transforms.TrivialAugmentWide(interpolation=interpolation) le = len(t._AUGMENTATION_SPACE) keys = list(t._AUGMENTATION_SPACE.keys()) @@ -961,15 +963,15 @@ def test_trivial_aug(self, inpt, interpolation, mocker): @pytest.mark.parametrize( "interpolation", [ - prototype_transforms.InterpolationMode.NEAREST, - prototype_transforms.InterpolationMode.BILINEAR, + v2_transforms.InterpolationMode.NEAREST, + v2_transforms.InterpolationMode.BILINEAR, PIL.Image.NEAREST, ], ) def 
test_augmix(self, inpt, interpolation, mocker): t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1) - t = prototype_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) + t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t._sample_dirichlet = lambda t: t.softmax(dim=-1) le = len(t._AUGMENTATION_SPACE) @@ -1014,15 +1016,15 @@ def test_augmix(self, inpt, interpolation, mocker): @pytest.mark.parametrize( "interpolation", [ - prototype_transforms.InterpolationMode.NEAREST, - prototype_transforms.InterpolationMode.BILINEAR, + v2_transforms.InterpolationMode.NEAREST, + v2_transforms.InterpolationMode.BILINEAR, PIL.Image.NEAREST, ], ) def test_aa(self, inpt, interpolation): aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = prototype_transforms.AutoAugment(aa_policy, interpolation=interpolation) + t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) torch.manual_seed(12) expected_output = t_ref(inpt) @@ -1087,10 +1089,16 @@ def make_datapoints(self, with_mask=True): @pytest.mark.parametrize( "t_ref, t, data_kwargs", [ - (det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}), - (det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}), - (det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}), - (det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}), + (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}), + # FIXME: make + # v2_transforms.Compose([ + # v2_transforms.RandomIoUCrop(), + # v2_transforms.SanitizeBoundingBoxes() + # ]) + # work + # (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}), + (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}), + (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}), ( det_transforms.FixedSizeCrop((1024, 1024), fill=0), prototype_transforms.FixedSizeCrop((1024, 1024), fill=0), @@ -1100,7 +1108,7 @@ def make_datapoints(self, with_mask=True): det_transforms.RandomShortestSize( min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 ), - prototype_transforms.RandomShortestSize( + v2_transforms.RandomShortestSize( min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 ), {}, @@ -1127,11 +1135,11 @@ def test_transform(self, t_ref, t, data_kwargs): # 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name # counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True` # 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size. 
-class PadIfSmaller(prototype_transforms.Transform): +class PadIfSmaller(v2_transforms.Transform): def __init__(self, size, fill=0): super().__init__() self.size = size - self.fill = prototype_transforms._geometry._setup_fill_arg(fill) + self.fill = v2_transforms._geometry._setup_fill_arg(fill) def _get_params(self, sample): height, width = query_spatial_size(sample) @@ -1193,27 +1201,27 @@ def check(self, t, t_ref, data_kwargs=None): [ ( seg_transforms.RandomHorizontalFlip(flip_prob=1.0), - prototype_transforms.RandomHorizontalFlip(p=1.0), + v2_transforms.RandomHorizontalFlip(p=1.0), dict(), ), ( seg_transforms.RandomHorizontalFlip(flip_prob=0.0), - prototype_transforms.RandomHorizontalFlip(p=0.0), + v2_transforms.RandomHorizontalFlip(p=0.0), dict(), ), ( seg_transforms.RandomCrop(size=480), - prototype_transforms.Compose( + v2_transforms.Compose( [ PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), - prototype_transforms.RandomCrop(size=480), + v2_transforms.RandomCrop(size=480), ] ), dict(), ), ( seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), dict(supports_pil=False, image_dtype=torch.float), ), ], @@ -1222,7 +1230,7 @@ def test_common(self, t_ref, t, data_kwargs): self.check(t, t_ref, data_kwargs) def check_resize(self, mocker, t_ref, t): - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") + mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") mock_ref = mocker.patch("torchvision.transforms.functional.resize") for dp, dp_ref in self.make_datapoints(): @@ -1263,9 +1271,9 @@ def patched_randint(a, b, *other_args, **kwargs): # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported # normally - t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) + t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) mocker.patch( - "torchvision.prototype.transforms._geometry.torch.randint", + "torchvision.transforms.v2._geometry.torch.randint", new=patched_randint, ) @@ -1277,7 +1285,7 @@ def test_random_resize_eval(self, mocker): torch.manual_seed(0) base_size = 520 - t = prototype_transforms.Resize(size=base_size, antialias=True) + t = v2_transforms.Resize(size=base_size, antialias=True) t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index bb4b6ef1158..7dff7a509ad 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -11,7 +11,6 @@ import torch -import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from prototype_common_utils import ( assert_close, @@ -22,11 +21,12 @@ from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F -from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding -from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box +from torchvision 
import datapoints from torchvision.transforms.functional import _get_perspective_coeffs +from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding +from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box +from torchvision.transforms.v2.utils import is_simple_tensor KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS} @@ -168,11 +168,7 @@ def _unbatch(self, batch, *, data_dims): def test_batched_vs_single(self, test_id, info, args_kwargs, device): (batched_input, *other_args), kwargs = args_kwargs.load(device) - datapoint_type = ( - datapoints.Image - if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input) - else type(batched_input) - ) + datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input) # This dictionary contains the number of rightmost dimensions that contain the actual data. # Everything to the left is considered a batch dimension. data_dims = { diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index befccf0bea3..c9d37466046 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -3,12 +3,12 @@ import torch -import torchvision.prototype.transforms.utils +import torchvision.transforms.v2.utils from prototype_common_utils import make_bounding_box, make_detection_mask, make_image -from torchvision.prototype import datapoints -from torchvision.prototype.transforms.functional import to_image_pil -from torchvision.prototype.transforms.utils import has_all, has_any +from torchvision import datapoints +from torchvision.transforms.v2.functional import to_image_pil +from torchvision.transforms.v2.utils import has_all, has_any IMAGE = make_image(color_space="RGB") @@ -37,15 +37,15 @@ ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), - ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True), + ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True), ( (torch.Tensor(IMAGE),), - (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), + (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True, ), ( (to_image_pil(IMAGE),), - (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), + (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True, ), ], diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py new file mode 100644 index 00000000000..04d5a05731e --- /dev/null +++ b/torchvision/datapoints/__init__.py @@ -0,0 +1,7 @@ +from ._bounding_box import BoundingBox, BoundingBoxFormat +from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT +from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT +from ._mask import Mask +from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT + +from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py similarity index 100% rename from 
torchvision/prototype/datapoints/_bounding_box.py rename to torchvision/datapoints/_bounding_box.py diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py similarity index 99% rename from torchvision/prototype/datapoints/_datapoint.py rename to torchvision/datapoints/_datapoint.py index 5f4a0d96ea2..2a2f34fc60e 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -105,7 +105,7 @@ def _F(self) -> ModuleType: # the class. This approach avoids the DataLoader issue described at # https://github.com/pytorch/vision/pull/6476#discussion_r953588621 if Datapoint.__F is None: - from ..transforms import functional + from ..transforms.v2 import functional Datapoint.__F = functional return Datapoint.__F diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py similarity index 98% rename from torchvision/prototype/datapoints/_dataset_wrapper.py rename to torchvision/datapoints/_dataset_wrapper.py index 74f83095177..dc4d1f4723c 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -8,9 +8,8 @@ import torch from torch.utils.data import Dataset -from torchvision import datasets -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F +from torchvision import datapoints, datasets +from torchvision.transforms.v2 import functional as F __all__ = ["wrap_dataset_for_transforms_v2"] diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/datapoints/_image.py similarity index 99% rename from torchvision/prototype/datapoints/_image.py rename to torchvision/datapoints/_image.py index 4fc14323abe..9c61740c563 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -24,7 +24,7 @@ def __new__( requires_grad: Optional[bool] = None, ) -> Image: if isinstance(data, PIL.Image.Image): - from torchvision.prototype.transforms import functional as F + from torchvision.transforms.v2 import functional as F data = F.pil_to_tensor(data) diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/datapoints/_mask.py similarity index 98% rename from torchvision/prototype/datapoints/_mask.py rename to torchvision/datapoints/_mask.py index 41dce097c6c..2746feaaf14 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -23,7 +23,7 @@ def __new__( requires_grad: Optional[bool] = None, ) -> Mask: if isinstance(data, PIL.Image.Image): - from torchvision.prototype.transforms import functional as F + from torchvision.transforms.v2 import functional as F data = F.pil_to_tensor(data) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/datapoints/_video.py similarity index 100% rename from torchvision/prototype/datapoints/_video.py rename to torchvision/datapoints/_video.py diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index 554088b912a..604628b2540 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,8 +1 @@ -from ._bounding_box import BoundingBox, BoundingBoxFormat -from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT -from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel -from ._mask import Mask -from ._video import 
TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT - -from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 0ee2eb9f8e1..7ed2f7522b0 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -5,7 +5,7 @@ import torch from torch.utils._pytree import tree_map -from ._datapoint import Datapoint +from torchvision.datapoints._datapoint import Datapoint L = TypeVar("L", bound="_LabelBase") diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index d8f560a36ff..f3882361638 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -6,7 +6,8 @@ import torch from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.datapoints import BoundingBox +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 66999c4c50b..2c819468778 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -4,7 +4,8 @@ import torch from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.datapoints import BoundingBox +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index de87f46c8b1..7d178291992 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -6,7 +6,8 @@ import numpy as np from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index e02ca706b1e..6616b4e3491 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -14,7 +14,8 @@ Mapper, UnBatcher, ) -from torchvision.prototype.datapoints import BoundingBox, Label, Mask +from torchvision.datapoints import BoundingBox, Mask +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index db561f89ec6..bc41ba028c5 100644 --- 
a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -15,7 +15,8 @@ Mapper, ) from torchdata.datapipes.map import IterToMapConverter -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.datapoints import BoundingBox +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index 73c6184b6e7..17f092aa328 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -3,7 +3,8 @@ import torch from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index adcc31b277a..85116ca3860 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.datapoints import BoundingBox +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 9364aa3ade9..8f22a33ae01 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,7 +7,8 @@ import torch from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE from torchvision.prototype.utils._internal import fromfile diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 9de224b95f0..4de5ae2765b 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -4,7 +4,8 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling 
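The builtin dataset diffs above and below all apply the same import split: Image, BoundingBox and Mask now come from the promoted torchvision.datapoints package, while Label and OneHotLabel stay under torchvision.prototype.datapoints. As a rough, hedged sketch of what downstream code looks like after this split (the tensor shapes, values and the RandomHorizontalFlip call are illustrative only, not part of this PR):

import torch

from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image, Mask
from torchvision.prototype.datapoints import Label  # Label/OneHotLabel remain prototype-only
from torchvision.transforms.v2 import RandomHorizontalFlip

# Illustrative sample; shapes and values are made up.
image = Image(torch.rand(3, 32, 32))
box = BoundingBox([[2.0, 2.0, 20.0, 20.0]], format=BoundingBoxFormat.XYXY, spatial_size=(32, 32))
mask = Mask(torch.zeros(32, 32, dtype=torch.uint8))
label = Label([1], categories=["cat", "dog"])

# v2 transforms accept arbitrary input structures and dispatch per datapoint type.
image, box, mask = RandomHorizontalFlip(p=1.0)(image, box, mask)
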
diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index 9ae2c17ab5d..92e1b93b410 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -3,7 +3,8 @@ import torch from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper -from torchvision.prototype.datapoints import Image, OneHotLabel +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import OneHotLabel from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 02db37169c1..a76b2dba270 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,7 +2,8 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.datapoints import BoundingBox +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index d276298ca02..94de4cf42c3 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -3,7 +3,8 @@ import numpy as np from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 7d1fed04e07..b5486669e21 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -3,7 +3,8 @@ import torch from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper -from torchvision.prototype.datapoints import Image, Label +from torchvision.datapoints import Image +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d14189132be..a13cfb764e4 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -5,8 +5,9 @@ from xml.etree import ElementTree from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.datapoints import BoundingBox from torchvision.datasets import VOCDetection -from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.prototype.datapoints import Label from 
torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, diff --git a/torchvision/prototype/datasets/utils/_encoded.py b/torchvision/prototype/datasets/utils/_encoded.py index 64cd9f7b951..8adc1e57acb 100644 --- a/torchvision/prototype/datasets/utils/_encoded.py +++ b/torchvision/prototype/datasets/utils/_encoded.py @@ -7,7 +7,7 @@ import PIL.Image import torch -from torchvision.prototype.datapoints._datapoint import Datapoint +from torchvision.datapoints._datapoint import Datapoint from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer D = TypeVar("D", bound="EncodedData") diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index ff3b758454a..4f8fdef484c 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,59 +1,6 @@ -from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip - -from . import functional, utils # usort: skip - -from ._transform import Transform # usort: skip from ._presets import StereoMatching # usort: skip -from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste -from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide -from ._color import ( - ColorJitter, - Grayscale, - RandomAdjustSharpness, - RandomAutocontrast, - RandomEqualize, - RandomGrayscale, - RandomInvert, - RandomPhotometricDistort, - RandomPosterize, - RandomSolarize, -) -from ._container import Compose, RandomApply, RandomChoice, RandomOrder -from ._geometry import ( - CenterCrop, - ElasticTransform, - FiveCrop, - FixedSizeCrop, - Pad, - RandomAffine, - RandomCrop, - RandomHorizontalFlip, - RandomIoUCrop, - RandomPerspective, - RandomResize, - RandomResizedCrop, - RandomRotation, - RandomShortestSize, - RandomVerticalFlip, - RandomZoomOut, - Resize, - ScaleJitter, - TenCrop, -) -from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype -from ._misc import ( - GaussianBlur, - Identity, - Lambda, - LinearTransformation, - Normalize, - PermuteDimensions, - SanitizeBoundingBoxes, - ToDtype, - TransposeDimensions, -) -from ._temporal import UniformTemporalSubsample -from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage - -from ._deprecated import ToTensor # usort: skip +from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste +from ._geometry import FixedSizeCrop +from ._misc import PermuteDimensions, TransposeDimensions +from ._type_conversion import LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3ceabba5e42..afa411b4896 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,112 +1,17 @@ import math -import numbers -import warnings from typing import Any, cast, Dict, List, Optional, Tuple, Union import PIL.Image import torch from torch.utils._pytree import tree_flatten, tree_unflatten -from torchvision import transforms as _transforms +from torchvision import datapoints from torchvision.ops import masks_to_boxes -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform -from torchvision.prototype.transforms.functional._geometry import _check_interpolation +from torchvision.prototype import 
datapoints as proto_datapoints +from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from ._transform import _RandomApplyTransform -from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size - - -class RandomErasing(_RandomApplyTransform): - _v1_transform_cls = _transforms.RandomErasing - - def _extract_params_for_v1_transform(self) -> Dict[str, Any]: - return dict( - super()._extract_params_for_v1_transform(), - value="random" if self.value is None else self.value, - ) - - _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) - - def __init__( - self, - p: float = 0.5, - scale: Tuple[float, float] = (0.02, 0.33), - ratio: Tuple[float, float] = (0.3, 3.3), - value: float = 0.0, - inplace: bool = False, - ): - super().__init__(p=p) - if not isinstance(value, (numbers.Number, str, tuple, list)): - raise TypeError("Argument value should be either a number or str or a sequence") - if isinstance(value, str) and value != "random": - raise ValueError("If value is str, it should be 'random'") - if not isinstance(scale, (tuple, list)): - raise TypeError("Scale should be a sequence") - if not isinstance(ratio, (tuple, list)): - raise TypeError("Ratio should be a sequence") - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - warnings.warn("Scale and ratio should be of kind (min, max)") - if scale[0] < 0 or scale[1] > 1: - raise ValueError("Scale should be between 0 and 1") - self.scale = scale - self.ratio = ratio - if isinstance(value, (int, float)): - self.value = [float(value)] - elif isinstance(value, str): - self.value = None - elif isinstance(value, (list, tuple)): - self.value = [float(v) for v in value] - else: - self.value = value - self.inplace = inplace - - self._log_ratio = torch.log(torch.tensor(self.ratio)) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - img_c, img_h, img_w = query_chw(flat_inputs) - - if self.value is not None and not (len(self.value) in (1, img_c)): - raise ValueError( - f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" - ) - - area = img_h * img_w - - log_ratio = self._log_ratio - for _ in range(10): - erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_( - log_ratio[0], # type: ignore[arg-type] - log_ratio[1], # type: ignore[arg-type] - ) - ).item() - - h = int(round(math.sqrt(erase_area * aspect_ratio))) - w = int(round(math.sqrt(erase_area / aspect_ratio))) - if not (h < img_h and w < img_w): - continue - - if self.value is None: - v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() - else: - v = torch.tensor(self.value)[:, None, None] - - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() - break - else: - i, j, h, w, v = 0, 0, img_h, img_w, None - - return dict(i=i, j=j, h=h, w=w, v=v) - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - if params["v"] is not None: - inpt = F.erase(inpt, **params, inplace=self.inplace) - - return inpt +from torchvision.transforms.v2._transform import _RandomApplyTransform +from torchvision.transforms.v2.functional._geometry import _check_interpolation +from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size class _BaseMixupCutmix(_RandomApplyTransform): @@ -118,19 +23,19 
@@ def __init__(self, alpha: float, p: float = 0.5) -> None: def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) - and has_any(flat_inputs, datapoints.OneHotLabel) + and has_any(flat_inputs, proto_datapoints.OneHotLabel) ): raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") - if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label): + if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label): raise TypeError( f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." ) - def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel: + def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel: if inpt.ndim < 2: raise ValueError("Need a batch of one hot labels") output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - return datapoints.OneHotLabel.wrap_like(inpt, output) + return proto_datapoints.OneHotLabel.wrap_like(inpt, output) class RandomMixup(_BaseMixupCutmix): @@ -149,7 +54,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] return output - elif isinstance(inpt, datapoints.OneHotLabel): + elif isinstance(inpt, proto_datapoints.OneHotLabel): return self._mixup_onehotlabel(inpt, lam) else: return inpt @@ -193,7 +98,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output - elif isinstance(inpt, datapoints.OneHotLabel): + elif isinstance(inpt, proto_datapoints.OneHotLabel): lam_adjusted = params["lam_adjusted"] return self._mixup_onehotlabel(inpt, lam_adjusted) else: @@ -307,7 +212,7 @@ def _extract_image_targets( bboxes.append(obj) elif isinstance(obj, datapoints.Mask): masks.append(obj) - elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): + elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): labels.append(obj) if not (len(images) == len(bboxes) == len(masks) == len(labels)): @@ -345,7 +250,7 @@ def _insert_outputs( elif isinstance(obj, datapoints.Mask): flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) c2 += 1 - elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): + elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] c3 += 1 diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 69238760be5..fa4ccef2eb9 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,829 +1,13 @@ -import math -import numbers -import warnings -from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union import PIL.Image import torch -from torchvision import transforms as _transforms -from torchvision.ops.boxes import box_iou -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform -from torchvision.prototype.transforms.functional._geometry import _check_interpolation -from 
torchvision.transforms.functional import _get_perspective_coeffs - -from ._transform import _RandomApplyTransform -from ._utils import ( - _check_padding_arg, - _check_padding_mode_arg, - _check_sequence_input, - _setup_angle, - _setup_fill_arg, - _setup_float_or_seq, - _setup_size, -) -from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size - - -class RandomHorizontalFlip(_RandomApplyTransform): - _v1_transform_cls = _transforms.RandomHorizontalFlip - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.horizontal_flip(inpt) - - -class RandomVerticalFlip(_RandomApplyTransform): - _v1_transform_cls = _transforms.RandomVerticalFlip - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.vertical_flip(inpt) - - -class Resize(Transform): - _v1_transform_cls = _transforms.Resize - - def __init__( - self, - size: Union[int, Sequence[int]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> None: - super().__init__() - - if isinstance(size, int): - size = [size] - elif isinstance(size, (list, tuple)) and len(size) in {1, 2}: - size = list(size) - else: - raise ValueError( - f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead." - ) - self.size = size - - self.interpolation = _check_interpolation(interpolation) - self.max_size = max_size - self.antialias = antialias - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize( - inpt, - self.size, - interpolation=self.interpolation, - max_size=self.max_size, - antialias=self.antialias, - ) - - -class CenterCrop(Transform): - _v1_transform_cls = _transforms.CenterCrop - - def __init__(self, size: Union[int, Sequence[int]]): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.center_crop(inpt, output_size=self.size) - - -class RandomResizedCrop(Transform): - _v1_transform_cls = _transforms.RandomResizedCrop - - def __init__( - self, - size: Union[int, Sequence[int]], - scale: Tuple[float, float] = (0.08, 1.0), - ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> None: - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - if not isinstance(scale, Sequence): - raise TypeError("Scale should be a sequence") - scale = cast(Tuple[float, float], scale) - if not isinstance(ratio, Sequence): - raise TypeError("Ratio should be a sequence") - ratio = cast(Tuple[float, float], ratio) - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - warnings.warn("Scale and ratio should be of kind (min, max)") - - self.scale = scale - self.ratio = ratio - self.interpolation = _check_interpolation(interpolation) - self.antialias = antialias - - self._log_ratio = torch.log(torch.tensor(self.ratio)) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) - area = height * width - - log_ratio = self._log_ratio - for _ in range(10): - target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_( - log_ratio[0], # type: 
ignore[arg-type] - log_ratio[1], # type: ignore[arg-type] - ) - ).item() - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if 0 < w <= width and 0 < h <= height: - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() - break - else: - # Fallback to central crop - in_ratio = float(width) / float(height) - if in_ratio < min(self.ratio): - w = width - h = int(round(w / min(self.ratio))) - elif in_ratio > max(self.ratio): - h = height - w = int(round(h * max(self.ratio))) - else: # whole image - w = width - h = height - i = (height - h) // 2 - j = (width - w) // 2 - - return dict(top=i, left=j, height=h, width=w) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resized_crop( - inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias - ) - - -ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] - - -class FiveCrop(Transform): - """ - Example: - >>> class BatchMultiCrop(transforms.Transform): - ... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]): - ... images_or_videos, labels = sample - ... batch_size = len(images_or_videos) - ... image_or_video = images_or_videos[0] - ... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos)) - ... labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size)) - ... return images_or_videos, labels - ... - >>> image = datapoints.Image(torch.rand(3, 256, 256)) - >>> label = datapoints.Label(0) - >>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()]) - >>> images, labels = transform(image, label) - >>> images.shape - torch.Size([5, 3, 224, 224]) - >>> labels.shape - torch.Size([5]) - """ - - _v1_transform_cls = _transforms.FiveCrop - - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def __init__(self, size: Union[int, Sequence[int]]) -> None: - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def _transform( - self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any] - ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: - return F.five_crop(inpt, self.size) - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): - raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") - - -class TenCrop(Transform): - """ - See :class:`~torchvision.prototype.transforms.FiveCrop` for an example. 
- """ - - _v1_transform_cls = _transforms.TenCrop - - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - self.vertical_flip = vertical_flip - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): - raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Tuple[ - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ]: - return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) - - -class Pad(Transform): - _v1_transform_cls = _transforms.Pad - - def _extract_params_for_v1_transform(self) -> Dict[str, Any]: - params = super()._extract_params_for_v1_transform() - - if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) - - return params - - def __init__( - self, - padding: Union[int, Sequence[int]], - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", - ) -> None: - super().__init__() - - _check_padding_arg(padding) - _check_padding_mode_arg(padding_mode) - - # This cast does Sequence[int] -> List[int] and is required to make mypy happy - if not isinstance(padding, int): - padding = list(padding) - self.padding = padding - self.fill = _setup_fill_arg(fill) - self.padding_mode = padding_mode - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] - - -class RandomZoomOut(_RandomApplyTransform): - def __init__( - self, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - side_range: Sequence[float] = (1.0, 4.0), - p: float = 0.5, - ) -> None: - super().__init__(p=p) - - self.fill = _setup_fill_arg(fill) - - _check_sequence_input(side_range, "side_range", req_sizes=(2,)) - - self.side_range = side_range - if side_range[0] < 1.0 or side_range[0] > side_range[1]: - raise ValueError(f"Invalid canvas side range provided {side_range}.") - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_h, orig_w = query_spatial_size(flat_inputs) - - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) - canvas_width = int(orig_w * r) - canvas_height = int(orig_h * r) - - r = torch.rand(2) - left = int((canvas_width - orig_w) * r[0]) - top = int((canvas_height - orig_h) * r[1]) - right = canvas_width - (left + orig_w) - bottom = canvas_height - (top + orig_h) - padding = [left, top, right, bottom] - - return dict(padding=padding) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.pad(inpt, **params, fill=fill) - - -class RandomRotation(Transform): - _v1_transform_cls = _transforms.RandomRotation - - def __init__( - self, 
- degrees: Union[numbers.Number, Sequence], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - center: Optional[List[float]] = None, - ) -> None: - super().__init__() - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) - self.interpolation = _check_interpolation(interpolation) - self.expand = expand - - self.fill = _setup_fill_arg(fill) - - if center is not None: - _check_sequence_input(center, "center", req_sizes=(2,)) - - self.center = center - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() - return dict(angle=angle) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.rotate( - inpt, - **params, - interpolation=self.interpolation, - expand=self.expand, - center=self.center, - fill=fill, - ) - - -class RandomAffine(Transform): - _v1_transform_cls = _transforms.RandomAffine - - def __init__( - self, - degrees: Union[numbers.Number, Sequence], - translate: Optional[Sequence[float]] = None, - scale: Optional[Sequence[float]] = None, - shear: Optional[Union[int, float, Sequence[float]]] = None, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - center: Optional[List[float]] = None, - ) -> None: - super().__init__() - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) - if translate is not None: - _check_sequence_input(translate, "translate", req_sizes=(2,)) - for t in translate: - if not (0.0 <= t <= 1.0): - raise ValueError("translation values should be between 0 and 1") - self.translate = translate - if scale is not None: - _check_sequence_input(scale, "scale", req_sizes=(2,)) - for s in scale: - if s <= 0: - raise ValueError("scale values should be positive") - self.scale = scale - - if shear is not None: - self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) - else: - self.shear = shear - - self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) - - if center is not None: - _check_sequence_input(center, "center", req_sizes=(2,)) - - self.center = center - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) - - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() - if self.translate is not None: - max_dx = float(self.translate[0] * width) - max_dy = float(self.translate[1] * height) - tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) - ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) - translate = (tx, ty) - else: - translate = (0, 0) - - if self.scale is not None: - scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() - else: - scale = 1.0 - - shear_x = shear_y = 0.0 - if self.shear is not None: - shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() - if len(self.shear) == 4: - shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() - - shear = (shear_x, shear_y) - return dict(angle=angle, translate=translate, scale=scale, shear=shear) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.affine( - inpt, - **params, - interpolation=self.interpolation, - fill=fill, - center=self.center, - ) - - -class 
RandomCrop(Transform): - _v1_transform_cls = _transforms.RandomCrop - - def _extract_params_for_v1_transform(self) -> Dict[str, Any]: - params = super()._extract_params_for_v1_transform() - - if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) - - padding = self.padding - if padding is not None: - pad_left, pad_right, pad_top, pad_bottom = padding - padding = [pad_left, pad_top, pad_right, pad_bottom] - params["padding"] = padding - - return params - - def __init__( - self, - size: Union[int, Sequence[int]], - padding: Optional[Union[int, Sequence[int]]] = None, - pad_if_needed: bool = False, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", - ) -> None: - super().__init__() - - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - if pad_if_needed or padding is not None: - if padding is not None: - _check_padding_arg(padding) - _check_padding_mode_arg(padding_mode) - - self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] - self.pad_if_needed = pad_if_needed - self.fill = _setup_fill_arg(fill) - self.padding_mode = padding_mode - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - padded_height, padded_width = query_spatial_size(flat_inputs) - - if self.padding is not None: - pad_left, pad_right, pad_top, pad_bottom = self.padding - padded_height += pad_top + pad_bottom - padded_width += pad_left + pad_right - else: - pad_left = pad_right = pad_top = pad_bottom = 0 - - cropped_height, cropped_width = self.size - - if self.pad_if_needed: - if padded_height < cropped_height: - diff = cropped_height - padded_height - - pad_top += diff - pad_bottom += diff - padded_height += 2 * diff - - if padded_width < cropped_width: - diff = cropped_width - padded_width - - pad_left += diff - pad_right += diff - padded_width += 2 * diff - - if padded_height < cropped_height or padded_width < cropped_width: - raise ValueError( - f"Required crop size {(cropped_height, cropped_width)} is larger than " - f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}." 
- ) - - # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad` - padding = [pad_left, pad_top, pad_right, pad_bottom] - needs_pad = any(padding) - - needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) - if padded_height > cropped_height - else (False, 0) - ) - needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) - if padded_width > cropped_width - else (False, 0) - ) - - return dict( - needs_crop=needs_vert_crop or needs_horz_crop, - top=top, - left=left, - height=cropped_height, - width=cropped_width, - needs_pad=needs_pad, - padding=padding, - ) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if params["needs_pad"]: - fill = self.fill[type(inpt)] - inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) - - if params["needs_crop"]: - inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) - - return inpt - - -class RandomPerspective(_RandomApplyTransform): - _v1_transform_cls = _transforms.RandomPerspective - - def __init__( - self, - distortion_scale: float = 0.5, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - p: float = 0.5, - ) -> None: - super().__init__(p=p) - - if not (0 <= distortion_scale <= 1): - raise ValueError("Argument distortion_scale value should be between 0 and 1") - - self.distortion_scale = distortion_scale - self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) - - distortion_scale = self.distortion_scale - - half_height = height // 2 - half_width = width // 2 - bound_height = int(distortion_scale * half_height) + 1 - bound_width = int(distortion_scale * half_width) + 1 - topleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), - ] - topright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), - ] - botright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), - ] - botleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), - ] - startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] - endpoints = [topleft, topright, botright, botleft] - perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) - return dict(coefficients=perspective_coeffs) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.perspective( - inpt, - None, - None, - fill=fill, - interpolation=self.interpolation, - **params, - ) - - -class ElasticTransform(Transform): - _v1_transform_cls = _transforms.ElasticTransform - - def __init__( - self, - alpha: Union[float, Sequence[float]] = 50.0, - sigma: Union[float, Sequence[float]] = 5.0, - fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - ) -> None: - super().__init__() - self.alpha = _setup_float_or_seq(alpha, "alpha", 2) - self.sigma = _setup_float_or_seq(sigma, "sigma", 2) - - 
self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - size = list(query_spatial_size(flat_inputs)) - - dx = torch.rand([1, 1] + size) * 2 - 1 - if self.sigma[0] > 0.0: - kx = int(8 * self.sigma[0] + 1) - # if kernel size is even we have to make it odd - if kx % 2 == 0: - kx += 1 - dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) - dx = dx * self.alpha[0] / size[0] - - dy = torch.rand([1, 1] + size) * 2 - 1 - if self.sigma[1] > 0.0: - ky = int(8 * self.sigma[1] + 1) - # if kernel size is even we have to make it odd - if ky % 2 == 0: - ky += 1 - dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) - dy = dy * self.alpha[1] / size[1] - displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 - return dict(displacement=displacement) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] - return F.elastic( - inpt, - **params, - fill=fill, - interpolation=self.interpolation, - ) - - -class RandomIoUCrop(Transform): - def __init__( - self, - min_scale: float = 0.3, - max_scale: float = 1.0, - min_aspect_ratio: float = 0.5, - max_aspect_ratio: float = 2.0, - sampler_options: Optional[List[float]] = None, - trials: int = 40, - ): - super().__init__() - # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 - self.min_scale = min_scale - self.max_scale = max_scale - self.min_aspect_ratio = min_aspect_ratio - self.max_aspect_ratio = max_aspect_ratio - if sampler_options is None: - sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] - self.options = sampler_options - self.trials = trials - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if not ( - has_all(flat_inputs, datapoints.BoundingBox) - and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) - and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel) - ): - raise TypeError( - f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " - "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks." 
- ) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_h, orig_w = query_spatial_size(flat_inputs) - bboxes = query_bounding_box(flat_inputs) - - while True: - # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) - min_jaccard_overlap = self.options[idx] - if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option - return dict() - - for _ in range(self.trials): - # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) - new_w = int(orig_w * r[0]) - new_h = int(orig_h * r[1]) - aspect_ratio = new_w / new_h - if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): - continue - - # check for 0 area crops - r = torch.rand(2) - left = int((orig_w - new_w) * r[0]) - top = int((orig_h - new_h) * r[1]) - right = left + new_w - bottom = top + new_h - if left == right or top == bottom: - continue - - # check for any valid boxes with centers within the crop area - xyxy_bboxes = F.convert_format_bounding_box( - bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY - ) - cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) - cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) - is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) - if not is_within_crop_area.any(): - continue - - # check at least 1 box with jaccard limitations - xyxy_bboxes = xyxy_bboxes[is_within_crop_area] - ious = box_iou( - xyxy_bboxes, - torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device), - ) - if ious.max() < min_jaccard_overlap: - continue - - return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if len(params) < 1: - return inpt - - is_within_crop_area = params["is_within_crop_area"] - - if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)): - return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] - - output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) - - if isinstance(output, datapoints.BoundingBox): - bboxes = output[is_within_crop_area] - bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) - output = datapoints.BoundingBox.wrap_like(output, bboxes) - elif isinstance(output, datapoints.Mask): - # apply is_within_crop_area if mask is one-hot encoded - masks = output[is_within_crop_area] - output = datapoints.Mask.wrap_like(output, masks) - - return output - - -class ScaleJitter(Transform): - def __init__( - self, - target_size: Tuple[int, int], - scale_range: Tuple[float, float] = (0.1, 2.0), - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ): - super().__init__() - self.target_size = target_size - self.scale_range = scale_range - self.interpolation = _check_interpolation(interpolation) - self.antialias = antialias - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_height, orig_width = query_spatial_size(flat_inputs) - - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) - r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale - new_width = int(orig_width * r) - new_height = int(orig_height * r) - - return dict(size=(new_height, new_width)) - - def _transform(self, inpt: 
Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) - - -class RandomShortestSize(Transform): - def __init__( - self, - min_size: Union[List[int], Tuple[int], int], - max_size: Optional[int] = None, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ): - super().__init__() - self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) - self.max_size = max_size - self.interpolation = _check_interpolation(interpolation) - self.antialias = antialias - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_height, orig_width = query_spatial_size(flat_inputs) - - min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] - r = min_size / min(orig_height, orig_width) - if self.max_size is not None: - r = min(r, self.max_size / max(orig_height, orig_width)) - - new_width = int(orig_width * r) - new_height = int(orig_height * r) - - return dict(size=(new_height, new_width)) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) +from torchvision import datapoints +from torchvision.prototype.datapoints import Label, OneHotLabel +from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size +from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_box, query_spatial_size class FixedSizeCrop(Transform): @@ -854,9 +38,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video." ) - if has_any(flat_inputs, datapoints.BoundingBox) and not has_any( - flat_inputs, datapoints.Label, datapoints.OneHotLabel - ): + if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(flat_inputs, Label, OneHotLabel): raise TypeError( f"If a BoundingBox is contained in the input sample, " f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." 
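For the prototype-only FixedSizeCrop kept above, the practical effect of this hunk is that a sample containing a BoundingBox must also carry a prototype Label or OneHotLabel, while the rest of the transform now runs on the stable torchvision.transforms.v2 machinery. A minimal, hedged usage sketch follows; the dict keys, sizes and values are made up for illustration and are not taken from this PR:

import torch

from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.transforms import FixedSizeCrop

sample = {
    "image": Image(torch.rand(3, 600, 800)),
    "boxes": BoundingBox(
        [[10.0, 10.0, 100.0, 100.0]], format=BoundingBoxFormat.XYXY, spatial_size=(600, 800)
    ),
    "labels": Label([0]),  # required because a BoundingBox is present in the sample
}

# Crops (and pads, if needed) every entry to the fixed 512x512 output size;
# boxes that end up outside the crop are dropped together with their labels.
out = FixedSizeCrop(size=(512, 512), fill=0)(sample)
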
@@ -927,7 +109,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["is_valid"] is not None: - if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)): + if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)): inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] elif isinstance(inpt, datapoints.BoundingBox): inpt = datapoints.BoundingBox.wrap_like( @@ -940,25 +122,3 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt - - -class RandomResize(Transform): - def __init__( - self, - min_size: int, - max_size: int, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> None: - super().__init__() - self.min_size = min_size - self.max_size = max_size - self.interpolation = _check_interpolation(interpolation) - self.antialias = antialias - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - size = int(torch.randint(self.min_size, self.max_size, ())) - return dict(size=[size]) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index caed3eec904..b51b59a1516 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,182 +1,13 @@ -import collections import warnings -from contextlib import suppress -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union - -import PIL.Image +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union import torch -from torch.utils._pytree import tree_flatten, tree_unflatten - -from torchvision import transforms as _transforms -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, Transform - -from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size -from .utils import has_any, is_simple_tensor, query_bounding_box - - -class Identity(Transform): - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return inpt - - -class Lambda(Transform): - def __init__(self, lambd: Callable[[Any], Any], *types: Type): - super().__init__() - self.lambd = lambd - self.types = types or (object,) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, self.types): - return self.lambd(inpt) - else: - return inpt - - def extra_repr(self) -> str: - extras = [] - name = getattr(self.lambd, "__name__", None) - if name: - extras.append(name) - extras.append(f"types={[type.__name__ for type in self.types]}") - return ", ".join(extras) - - -class LinearTransformation(Transform): - _v1_transform_cls = _transforms.LinearTransformation - - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) - - def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): - super().__init__() - if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError( - "transformation_matrix should be square. Got " - f"{tuple(transformation_matrix.size())} rectangular matrix." 
- ) - - if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError( - f"mean_vector should have the same length {mean_vector.size(0)}" - f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]" - ) - - if transformation_matrix.device != mean_vector.device: - raise ValueError( - f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" - ) - - if transformation_matrix.dtype != mean_vector.dtype: - raise ValueError( - f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" - ) - - self.transformation_matrix = transformation_matrix - self.mean_vector = mean_vector - - def _check_inputs(self, sample: Any) -> Any: - if has_any(sample, PIL.Image.Image): - raise TypeError("LinearTransformation does not work on PIL Images") - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - shape = inpt.shape - n = shape[-3] * shape[-2] * shape[-1] - if n != self.transformation_matrix.shape[0]: - raise ValueError( - "Input tensor and transformation matrix have incompatible shape." - + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != " - + f"{self.transformation_matrix.shape[0]}" - ) - - if inpt.device.type != self.mean_vector.device.type: - raise ValueError( - "Input tensor should be on the same device as transformation matrix and mean vector. " - f"Got {inpt.device} vs {self.mean_vector.device}" - ) - - flat_inpt = inpt.reshape(-1, n) - self.mean_vector - - transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype) - output = torch.mm(flat_inpt, transformation_matrix) - output = output.reshape(shape) - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] - return output - - -class Normalize(Transform): - _v1_transform_cls = _transforms.Normalize - _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) - - def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): - super().__init__() - self.mean = list(mean) - self.std = list(std) - self.inplace = inplace - - def _check_inputs(self, sample: Any) -> Any: - if has_any(sample, PIL.Image.Image): - raise TypeError(f"{type(self).__name__}() does not support PIL images.") - - def _transform( - self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] - ) -> Any: - return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) - - -class GaussianBlur(Transform): - _v1_transform_cls = _transforms.GaussianBlur - - def __init__( - self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) - ) -> None: - super().__init__() - self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") - for ks in self.kernel_size: - if ks <= 0 or ks % 2 == 0: - raise ValueError("Kernel size value should be an odd and positive number.") - - if isinstance(sigma, (int, float)): - if sigma <= 0: - raise ValueError("If sigma is a single number, it must be positive.") - sigma = float(sigma) - elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0.0 < sigma[0] <= sigma[1]: - raise ValueError("sigma values should be positive and of the form (min, max).") - else: - raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.") - - self.sigma = _setup_float_or_seq(sigma, "sigma", 2) - - def _get_params(self, 
flat_inputs: List[Any]) -> Dict[str, Any]: - sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() - return dict(sigma=[sigma, sigma]) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.gaussian_blur(inpt, self.kernel_size, **params) - -class ToDtype(Transform): - _transformed_types = (torch.Tensor,) +from torchvision import datapoints +from torchvision.transforms.v2 import Transform - def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: - super().__init__() - if not isinstance(dtype, dict): - dtype = _get_defaultdict(dtype) - if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): - warnings.warn( - "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " - "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " - "in case a `datapoints.Image` or `datapoints.Video` is present in the input." - ) - self.dtype = dtype - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - dtype = self.dtype[type(inpt)] - if dtype is None: - return inpt - return inpt.to(dtype=dtype) +from torchvision.transforms.v2._utils import _get_defaultdict +from torchvision.transforms.v2.utils import is_simple_tensor class PermuteDimensions(Transform): @@ -225,115 +56,3 @@ def _transform( if dims is None: return inpt.as_subclass(torch.Tensor) return inpt.transpose(*dims) - - -class SanitizeBoundingBoxes(Transform): - # This removes boxes and their corresponding labels: - # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) - # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) - - def __init__( - self, - min_size: float = 1.0, - labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", - ) -> None: - super().__init__() - - if min_size < 1: - raise ValueError(f"min_size must be >= 1, got {min_size}.") - self.min_size = min_size - - self.labels_getter = labels_getter - self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] - if labels_getter == "default": - self._labels_getter = self._find_labels_default_heuristic - elif callable(labels_getter): - self._labels_getter = labels_getter - elif isinstance(labels_getter, str): - self._labels_getter = lambda inputs: inputs[labels_getter] - elif labels_getter is None: - self._labels_getter = None - else: - raise ValueError( - "labels_getter should either be a str, callable, or 'default'. " - f"Got {labels_getter} of type {type(labels_getter)}." - ) - - @staticmethod - def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive - # Returns None if nothing is found - candidate_key = None - with suppress(StopIteration): - candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") - if candidate_key is None: - with suppress(StopIteration): - candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) - if candidate_key is None: - raise ValueError( - "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" - "If there are no samples and it is by design, pass labels_getter=None." 
- ) - return inputs[candidate_key] - - def forward(self, *inputs: Any) -> Any: - inputs = inputs if len(inputs) > 1 else inputs[0] - - if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping): - raise ValueError( - f"If labels_getter is a str or 'default' (got {self.labels_getter}), " - f"then the input to forward() must be a dict. Got {type(inputs)} instead." - ) - - if self._labels_getter is None: - labels = None - else: - labels = self._labels_getter(inputs) - if labels is not None and not isinstance(labels, torch.Tensor): - raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") - - flat_inputs, spec = tree_flatten(inputs) - # TODO: this enforces one single BoundingBox entry. - # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... - # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? - boxes = query_bounding_box(flat_inputs) - - if boxes.ndim != 2: - raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") - - if labels is not None and boxes.shape[0] != labels.shape[0]: - raise ValueError( - f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." - ) - - boxes = cast( - datapoints.BoundingBox, - F.convert_format_bounding_box( - boxes, - new_format=datapoints.BoundingBoxFormat.XYXY, - ), - ) - ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] - mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) - # TODO: Do we really need to check for out of bounds here? All - # transforms should be clamping anyway, so this should never happen? - image_h, image_w = boxes.spatial_size - mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) - mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) - - params = dict(mask=mask, labels=labels) - flat_outputs = [ - # Even-though it may look like we're transforming all inputs, we don't: - # _transform() will only care about BoundingBoxes and the labels - self._transform(inpt, params) - for inpt in flat_inputs - ] - - return tree_unflatten(flat_outputs, spec) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - - if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): - inpt = inpt[params["mask"]] - - return inpt diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 7f18e885c39..25c39a90382 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -9,9 +9,9 @@ import torch from torch import Tensor -from torchvision.prototype.transforms.functional._geometry import _check_interpolation +from torchvision.transforms.v2 import functional as F, InterpolationMode -from . 
import functional as F, InterpolationMode +from torchvision.transforms.v2.functional._geometry import _check_interpolation __all__ = ["StereoMatching"] diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index c84aee62afe..4cd3cf46871 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,67 +1,29 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict -import numpy as np -import PIL.Image import torch from torch.nn.functional import one_hot -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, Transform - -from torchvision.prototype.transforms.utils import is_simple_tensor +from torchvision.prototype import datapoints as proto_datapoints +from torchvision.transforms.v2 import Transform class LabelToOneHot(Transform): - _transformed_types = (datapoints.Label,) + _transformed_types = (proto_datapoints.Label,) def __init__(self, num_categories: int = -1): super().__init__() self.num_categories = num_categories - def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel: + def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel: num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories) output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) - return datapoints.OneHotLabel(output, categories=inpt.categories) + return proto_datapoints.OneHotLabel(output, categories=inpt.categories) def extra_repr(self) -> str: if self.num_categories == -1: return "" return f"num_categories={self.num_categories}" - - -class PILToTensor(Transform): - _transformed_types = (PIL.Image.Image,) - - def _transform(self, inpt: Union[PIL.Image.Image], params: Dict[str, Any]) -> torch.Tensor: - return F.pil_to_tensor(inpt) - - -class ToImageTensor(Transform): - _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) - - def _transform( - self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] - ) -> datapoints.Image: - return F.to_image_tensor(inpt) - - -class ToImagePIL(Transform): - _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) - - def __init__(self, mode: Optional[str] = None) -> None: - super().__init__() - self.mode = mode - - def _transform( - self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] - ) -> PIL.Image.Image: - return F.to_image_pil(inpt, mode=self.mode) - - -# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. -ToPILImage = ToImagePIL diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py new file mode 100644 index 00000000000..520e0088e82 --- /dev/null +++ b/torchvision/transforms/v2/__init__.py @@ -0,0 +1,47 @@ +from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip + +from . 
import functional, utils # usort: skip + +from ._transform import Transform # usort: skip + +from ._augment import RandomErasing +from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide +from ._color import ( + ColorJitter, + Grayscale, + RandomAdjustSharpness, + RandomAutocontrast, + RandomEqualize, + RandomGrayscale, + RandomInvert, + RandomPhotometricDistort, + RandomPosterize, + RandomSolarize, +) +from ._container import Compose, RandomApply, RandomChoice, RandomOrder +from ._geometry import ( + CenterCrop, + ElasticTransform, + FiveCrop, + Pad, + RandomAffine, + RandomCrop, + RandomHorizontalFlip, + RandomIoUCrop, + RandomPerspective, + RandomResize, + RandomResizedCrop, + RandomRotation, + RandomShortestSize, + RandomVerticalFlip, + RandomZoomOut, + Resize, + ScaleJitter, + TenCrop, +) +from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype +from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype +from ._temporal import UniformTemporalSubsample +from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage + +from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py new file mode 100644 index 00000000000..1375400ed0c --- /dev/null +++ b/torchvision/transforms/v2/_augment.py @@ -0,0 +1,105 @@ +import math +import numbers +import warnings +from typing import Any, Dict, List, Tuple, Union + +import PIL.Image +import torch +from torchvision import datapoints, transforms as _transforms +from torchvision.transforms.v2 import functional as F + +from ._transform import _RandomApplyTransform +from .utils import is_simple_tensor, query_chw + + +class RandomErasing(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomErasing + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return dict( + super()._extract_params_for_v1_transform(), + value="random" if self.value is None else self.value, + ) + + _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) + + def __init__( + self, + p: float = 0.5, + scale: Tuple[float, float] = (0.02, 0.33), + ratio: Tuple[float, float] = (0.3, 3.3), + value: float = 0.0, + inplace: bool = False, + ): + super().__init__(p=p) + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + self.scale = scale + self.ratio = ratio + if isinstance(value, (int, float)): + self.value = [float(value)] + elif isinstance(value, str): + self.value = None + elif isinstance(value, (list, tuple)): + self.value = [float(v) for v in value] + else: + self.value = value + self.inplace = inplace + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + img_c, img_h, img_w = query_chw(flat_inputs) + + if self.value is not None and not (len(self.value) in (1, 
img_c)): + raise ValueError( + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" + ) + + area = img_h * img_w + + log_ratio = self._log_ratio + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if self.value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(self.value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + break + else: + i, j, h, w, v = 0, 0, img_h, img_w, None + + return dict(i=i, j=j, h=h, w=w, v=v) + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + if params["v"] is not None: + inpt = F.erase(inpt, **params, inplace=self.inplace) + + return inpt diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py similarity index 98% rename from torchvision/prototype/transforms/_auto_augment.py rename to torchvision/transforms/v2/_auto_augment.py index 67afecf5df1..bdc3c89d7f3 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -5,12 +5,11 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec -from torchvision import transforms as _transforms -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform -from torchvision.prototype.transforms.functional._geometry import _check_interpolation -from torchvision.prototype.transforms.functional._meta import get_spatial_size +from torchvision import datapoints, transforms as _transforms from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform +from torchvision.transforms.v2.functional._geometry import _check_interpolation +from torchvision.transforms.v2.functional._meta import get_spatial_size from ._utils import _setup_fill_arg from .utils import check_type, is_simple_tensor diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/transforms/v2/_color.py similarity index 98% rename from torchvision/prototype/transforms/_color.py rename to torchvision/transforms/v2/_color.py index 8ac0d857753..f1b04d775d7 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -3,9 +3,8 @@ import PIL.Image import torch -from torchvision import transforms as _transforms -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, Transform +from torchvision import datapoints, transforms as _transforms +from torchvision.transforms.v2 import functional as F, Transform from ._transform import _RandomApplyTransform from .utils import is_simple_tensor, query_chw diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/transforms/v2/_container.py similarity index 98% rename from torchvision/prototype/transforms/_container.py rename to 
torchvision/transforms/v2/_container.py index 42c73a2c11e..555010fda1e 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -5,7 +5,7 @@ from torch import nn from torchvision import transforms as _transforms -from torchvision.prototype.transforms import Transform +from torchvision.transforms.v2 import Transform class Compose(Transform): diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/transforms/v2/_deprecated.py similarity index 92% rename from torchvision/prototype/transforms/_deprecated.py rename to torchvision/transforms/v2/_deprecated.py index cd37f4d73d0..bfb0d06239f 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/transforms/v2/_deprecated.py @@ -4,10 +4,10 @@ import numpy as np import PIL.Image import torch - -from torchvision.prototype.transforms import Transform from torchvision.transforms import functional as _F +from torchvision.transforms.v2 import Transform + class ToTensor(Transform): _transformed_types = (PIL.Image.Image, np.ndarray) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py new file mode 100644 index 00000000000..6a8e4a3e033 --- /dev/null +++ b/torchvision/transforms/v2/_geometry.py @@ -0,0 +1,847 @@ +import math +import numbers +import warnings +from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union + +import PIL.Image +import torch + +from torchvision import datapoints, transforms as _transforms +from torchvision.ops.boxes import box_iou +from torchvision.transforms.functional import _get_perspective_coeffs +from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform +from torchvision.transforms.v2.functional._geometry import _check_interpolation + +from ._transform import _RandomApplyTransform +from ._utils import ( + _check_padding_arg, + _check_padding_mode_arg, + _check_sequence_input, + _setup_angle, + _setup_fill_arg, + _setup_float_or_seq, + _setup_size, +) +from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size + + +class RandomHorizontalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomHorizontalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.horizontal_flip(inpt) + + +class RandomVerticalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomVerticalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.vertical_flip(inpt) + + +class Resize(Transform): + _v1_transform_cls = _transforms.Resize + + def __init__( + self, + size: Union[int, Sequence[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + + if isinstance(size, int): + size = [size] + elif isinstance(size, (list, tuple)) and len(size) in {1, 2}: + size = list(size) + else: + raise ValueError( + f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead." 
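Resize now lives under torchvision.transforms.v2 and, as the size handling in this __init__ shows, accepts either a single int (shorter edge) or a one- or two-element sequence. A usage sketch, independent of the patch; antialias is set explicitly here to sidestep the transitional "warn" default used above.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Resize

img = datapoints.Image(torch.rand(3, 240, 320))

# int: the shorter edge is resized to 224, keeping the aspect ratio.
by_short_edge = Resize(224, antialias=True)(img)
# (h, w) pair: exact output size.
fixed_size = Resize((224, 224), antialias=True)(img)

print(by_short_edge.shape, fixed_size.shape)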
+ ) + self.size = size + + self.interpolation = _check_interpolation(interpolation) + self.max_size = max_size + self.antialias = antialias + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) + + +class CenterCrop(Transform): + _v1_transform_cls = _transforms.CenterCrop + + def __init__(self, size: Union[int, Sequence[int]]): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.center_crop(inpt, output_size=self.size) + + +class RandomResizedCrop(Transform): + _v1_transform_cls = _transforms.RandomResizedCrop + + def __init__( + self, + size: Union[int, Sequence[int]], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + scale = cast(Tuple[float, float], scale) + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + ratio = cast(Tuple[float, float], ratio) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + self.scale = scale + self.ratio = ratio + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + area = height * width + + log_ratio = self._log_ratio + for _ in range(10): + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + + return dict(top=i, left=j, height=h, width=w) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resized_crop( + inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + ) + + +ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] + + +class FiveCrop(Transform): + """ + Example: + >>> class BatchMultiCrop(transforms.Transform): + ... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]): + ... images_or_videos, labels = sample + ... batch_size = len(images_or_videos) + ... image_or_video = images_or_videos[0] + ... 
images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos)) + ... labels = torch.full((batch_size,), label, device=images_or_videos.device) + ... return images_or_videos, labels + ... + >>> image = datapoints.Image(torch.rand(3, 256, 256)) + >>> label = 3 + >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()]) + >>> images, labels = transform(image, label) + >>> images.shape + torch.Size([5, 3, 224, 224]) + >>> labels + tensor([3, 3, 3, 3, 3]) + """ + + _v1_transform_cls = _transforms.FiveCrop + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform( + self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any] + ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + return F.five_crop(inpt, self.size) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): + raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + + +class TenCrop(Transform): + """ + See :class:`~torchvision.transforms.v2.FiveCrop` for an example. + """ + + _v1_transform_cls = _transforms.TenCrop + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): + raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Tuple[ + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ]: + return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) + + +class Pad(Transform): + _v1_transform_cls = _transforms.Pad + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." 
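Pad and the other geometry transforms in this file take fill either as a single value or as a dict keyed by input type, resolved per input via self.fill[type(inpt)] (see _transform below). A sketch of the per-type form, independent of the patch; the Mask fill value of 255 is just an illustrative choice.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Pad

img = datapoints.Image(torch.zeros(3, 100, 100, dtype=torch.uint8))
mask = datapoints.Mask(torch.zeros(100, 100, dtype=torch.int64))

# Pad images with gray, masks with an "ignore" value.
transform = Pad(padding=10, fill={datapoints.Image: 127, datapoints.Mask: 255})
padded_img, padded_mask = transform(img, mask)
print(padded_img.shape, padded_mask.shape)  # (3, 120, 120), (120, 120)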
+ ) + + return params + + def __init__( + self, + padding: Union[int, Sequence[int]], + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + # This cast does Sequence[int] -> List[int] and is required to make mypy happy + if not isinstance(padding, int): + padding = list(padding) + self.padding = padding + self.fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + + +class RandomZoomOut(_RandomApplyTransform): + def __init__( + self, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + side_range: Sequence[float] = (1.0, 4.0), + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + self.fill = _setup_fill_arg(fill) + + _check_sequence_input(side_range, "side_range", req_sizes=(2,)) + + self.side_range = side_range + if side_range[0] < 1.0 or side_range[0] > side_range[1]: + raise ValueError(f"Invalid canvas side range provided {side_range}.") + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_spatial_size(flat_inputs) + + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + canvas_width = int(orig_w * r) + canvas_height = int(orig_h * r) + + r = torch.rand(2) + left = int((canvas_width - orig_w) * r[0]) + top = int((canvas_height - orig_h) * r[1]) + right = canvas_width - (left + orig_w) + bottom = canvas_height - (top + orig_h) + padding = [left, top, right, bottom] + + return dict(padding=padding) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.pad(inpt, **params, fill=fill) + + +class RandomRotation(Transform): + _v1_transform_cls = _transforms.RandomRotation + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + center: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.interpolation = _check_interpolation(interpolation) + self.expand = expand + + self.fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + return dict(angle=angle) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.rotate( + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + center=self.center, + fill=fill, + ) + + +class RandomAffine(Transform): + _v1_transform_cls = _transforms.RandomAffine + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + translate: Optional[Sequence[float]] = None, + scale: Optional[Sequence[float]] = None, + shear: Optional[Union[int, float, Sequence[float]]] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] 
= 0, + center: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + if self.translate is not None: + max_dx = float(self.translate[0] * width) + max_dy = float(self.translate[1] * height) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translate = (tx, ty) + else: + translate = (0, 0) + + if self.scale is not None: + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if self.shear is not None: + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + if len(self.shear) == 4: + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + + shear = (shear_x, shear_y) + return dict(angle=angle, translate=translate, scale=scale, shear=shear) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.affine( + inpt, + **params, + interpolation=self.interpolation, + fill=fill, + center=self.center, + ) + + +class RandomCrop(Transform): + _v1_transform_cls = _transforms.RandomCrop + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." 
+ ) + + padding = self.padding + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + padding = [pad_left, pad_top, pad_right, pad_bottom] + params["padding"] = padding + + return params + + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if pad_if_needed or padding is not None: + if padding is not None: + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] + self.pad_if_needed = pad_if_needed + self.fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + padded_height, padded_width = query_spatial_size(flat_inputs) + + if self.padding is not None: + pad_left, pad_right, pad_top, pad_bottom = self.padding + padded_height += pad_top + pad_bottom + padded_width += pad_left + pad_right + else: + pad_left = pad_right = pad_top = pad_bottom = 0 + + cropped_height, cropped_width = self.size + + if self.pad_if_needed: + if padded_height < cropped_height: + diff = cropped_height - padded_height + + pad_top += diff + pad_bottom += diff + padded_height += 2 * diff + + if padded_width < cropped_width: + diff = cropped_width - padded_width + + pad_left += diff + pad_right += diff + padded_width += 2 * diff + + if padded_height < cropped_height or padded_width < cropped_width: + raise ValueError( + f"Required crop size {(cropped_height, cropped_width)} is larger than " + f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}." 
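A concrete instance of the pad_if_needed branch above, sketched independently of the patch: cropping 224x224 from a 200-pixel-tall input first adds the 24-pixel height deficit to both pad_top and pad_bottom, so the offset sampling below sees a 248x300 canvas.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import RandomCrop

img = datapoints.Image(torch.rand(3, 200, 300))
crop = RandomCrop(224, pad_if_needed=True)
out = crop(img)
print(out.shape)  # torch.Size([3, 224, 224]); height is padded from 200 to 248 before cropping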
+ ) + + # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad` + padding = [pad_left, pad_top, pad_right, pad_bottom] + needs_pad = any(padding) + + needs_vert_crop, top = ( + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + if padded_height > cropped_height + else (False, 0) + ) + needs_horz_crop, left = ( + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + if padded_width > cropped_width + else (False, 0) + ) + + return dict( + needs_crop=needs_vert_crop or needs_horz_crop, + top=top, + left=left, + height=cropped_height, + width=cropped_width, + needs_pad=needs_pad, + padding=padding, + ) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["needs_pad"]: + fill = self.fill[type(inpt)] + inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) + + if params["needs_crop"]: + inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + + return inpt + + +class RandomPerspective(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomPerspective + + def __init__( + self, + distortion_scale: float = 0.5, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + if not (0 <= distortion_scale <= 1): + raise ValueError("Argument distortion_scale value should be between 0 and 1") + + self.distortion_scale = distortion_scale + self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + + distortion_scale = self.distortion_scale + + half_height = height // 2 + half_width = width // 2 + bound_height = int(distortion_scale * half_height) + 1 + bound_width = int(distortion_scale * half_width) + 1 + topleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + topright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + botright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + botleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) + return dict(coefficients=perspective_coeffs) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.perspective( + inpt, + None, + None, + fill=fill, + interpolation=self.interpolation, + **params, + ) + + +class ElasticTransform(Transform): + _v1_transform_cls = _transforms.ElasticTransform + + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.alpha = _setup_float_or_seq(alpha, "alpha", 2) + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + 
self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = list(query_spatial_size(flat_inputs)) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.elastic( + inpt, + **params, + fill=fill, + interpolation=self.interpolation, + ) + + +class RandomIoUCrop(Transform): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): + super().__init__() + # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + if sampler_options is None: + sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] + self.options = sampler_options + self.trials = trials + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if not ( + has_all(flat_inputs, datapoints.BoundingBox) + and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) + ): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain tensor or PIL images " + "and bounding boxes. Sample can also contain masks." + ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_spatial_size(flat_inputs) + bboxes = query_bounding_box(flat_inputs) + + while True: + # sample an option + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + min_jaccard_overlap = self.options[idx] + if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option + return dict() + + for _ in range(self.trials): + # check the aspect ratio limitations + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + new_w = int(orig_w * r[0]) + new_h = int(orig_h * r[1]) + aspect_ratio = new_w / new_h + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + continue + + # check for 0 area crops + r = torch.rand(2) + left = int((orig_w - new_w) * r[0]) + top = int((orig_h - new_h) * r[1]) + right = left + new_w + bottom = top + new_h + if left == right or top == bottom: + continue + + # FIXME: I think we can stop here? 
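# The candidate crop sampled above is accepted only if both checks below pass:
# (1) at least one bounding box center falls strictly inside the crop, and
# (2) among those boxes, the best IoU with the crop reaches min_jaccard_overlap;
# otherwise the loop moves on to the next of the `trials` attempts.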
+ + # check for any valid boxes with centers within the crop area + xyxy_bboxes = F.convert_format_bounding_box( + bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY + ) + cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) + cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) + is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) + if not is_within_crop_area.any(): + continue + + # check at least 1 box with jaccard limitations + xyxy_bboxes = xyxy_bboxes[is_within_crop_area] + ious = box_iou( + xyxy_bboxes, + torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device), + ) + if ious.max() < min_jaccard_overlap: + continue + + return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + # FIXME: refactor this to not remove anything + + if len(params) < 1: + return inpt + + is_within_crop_area = params["is_within_crop_area"] + + output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + + if isinstance(output, datapoints.BoundingBox): + bboxes = output[is_within_crop_area] + bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) + output = datapoints.BoundingBox.wrap_like(output, bboxes) + elif isinstance(output, datapoints.Mask): + # apply is_within_crop_area if mask is one-hot encoded + masks = output[is_within_crop_area] + output = datapoints.Mask.wrap_like(output, masks) + + return output + + +class ScaleJitter(Transform): + def __init__( + self, + target_size: Tuple[int, int], + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ): + super().__init__() + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_spatial_size(flat_inputs) + + scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return dict(size=(new_height, new_width)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + + +class RandomShortestSize(Transform): + def __init__( + self, + min_size: Union[List[int], Tuple[int], int], + max_size: Optional[int] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ): + super().__init__() + self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) + self.max_size = max_size + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_spatial_size(flat_inputs) + + min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + r = min_size / min(orig_height, orig_width) + if self.max_size is not None: + r = min(r, self.max_size / max(orig_height, orig_width)) + + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return 
dict(size=(new_height, new_width)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + + +class RandomResize(Transform): + def __init__( + self, + min_size: int, + max_size: int, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + self.min_size = min_size + self.max_size = max_size + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = int(torch.randint(self.min_size, self.max_size, ())) + return dict(size=[size]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/transforms/v2/_meta.py similarity index 90% rename from torchvision/prototype/transforms/_meta.py rename to torchvision/transforms/v2/_meta.py index 79bd5549b2e..6e6655d0b54 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -2,9 +2,8 @@ import torch -from torchvision import transforms as _transforms -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, Transform +from torchvision import datapoints, transforms as _transforms +from torchvision.transforms.v2 import functional as F, Transform from .utils import is_simple_tensor diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py new file mode 100644 index 00000000000..89e743dae3d --- /dev/null +++ b/torchvision/transforms/v2/_misc.py @@ -0,0 +1,290 @@ +import collections +import warnings +from contextlib import suppress +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union + +import PIL.Image + +import torch +from torch.utils._pytree import tree_flatten, tree_unflatten + +from torchvision import datapoints, transforms as _transforms +from torchvision.transforms.v2 import functional as F, Transform + +from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size +from .utils import has_any, is_simple_tensor, query_bounding_box + + +class Identity(Transform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + +class Lambda(Transform): + def __init__(self, lambd: Callable[[Any], Any], *types: Type): + super().__init__() + self.lambd = lambd + self.types = types or (object,) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, self.types): + return self.lambd(inpt) + else: + return inpt + + def extra_repr(self) -> str: + extras = [] + name = getattr(self.lambd, "__name__", None) + if name: + extras.append(name) + extras.append(f"types={[type.__name__ for type in self.types]}") + return ", ".join(extras) + + +class LinearTransformation(Transform): + _v1_transform_cls = _transforms.LinearTransformation + + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + + def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): + super().__init__() + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError( + "transformation_matrix should be square. Got " + f"{tuple(transformation_matrix.size())} rectangular matrix." 
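LinearTransformation flattens every input to a row vector of length C*H*W before applying the matrix (see _transform further down), so transformation_matrix must be square with that many rows and mean_vector must have the same length. A minimal sketch under that assumption, using an identity matrix so the input is unchanged.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import LinearTransformation

c, h, w = 3, 8, 8
n = c * h * w
transform = LinearTransformation(
    transformation_matrix=torch.eye(n),  # identity: output equals input
    mean_vector=torch.zeros(n),
)
out = transform(datapoints.Image(torch.rand(c, h, w)))
print(out.shape)  # torch.Size([3, 8, 8])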
+ ) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError( + f"mean_vector should have the same length {mean_vector.size(0)}" + f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]" + ) + + if transformation_matrix.device != mean_vector.device: + raise ValueError( + f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" + ) + + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError("LinearTransformation does not work on PIL Images") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + shape = inpt.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != " + + f"{self.transformation_matrix.shape[0]}" + ) + + if inpt.device.type != self.mean_vector.device.type: + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + f"Got {inpt.device} vs {self.mean_vector.device}" + ) + + flat_inpt = inpt.reshape(-1, n) - self.mean_vector + + transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype) + output = torch.mm(flat_inpt, transformation_matrix) + output = output.reshape(shape) + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + return output + + +class Normalize(Transform): + _v1_transform_cls = _transforms.Normalize + _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) + + def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): + super().__init__() + self.mean = list(mean) + self.std = list(std) + self.inplace = inplace + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError(f"{type(self).__name__}() does not support PIL images.") + + def _transform( + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> Any: + return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) + + +class GaussianBlur(Transform): + _v1_transform_cls = _transforms.GaussianBlur + + def __init__( + self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) + ) -> None: + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, (int, float)): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = float(sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.") + + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + def _get_params(self, 
flat_inputs: List[Any]) -> Dict[str, Any]: + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + return dict(sigma=[sigma, sigma]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.gaussian_blur(inpt, self.kernel_size, **params) + + +class ToDtype(Transform): + _transformed_types = (torch.Tensor,) + + def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: + super().__init__() + if not isinstance(dtype, dict): + dtype = _get_defaultdict(dtype) + if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) + self.dtype = dtype + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + dtype = self.dtype[type(inpt)] + if dtype is None: + return inpt + return inpt.to(dtype=dtype) + + +class SanitizeBoundingBoxes(Transform): + # This removes boxes and their corresponding labels: + # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) + # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) + + def __init__( + self, + min_size: float = 1.0, + labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", + ) -> None: + super().__init__() + + if min_size < 1: + raise ValueError(f"min_size must be >= 1, got {min_size}.") + self.min_size = min_size + + self.labels_getter = labels_getter + self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] + if labels_getter == "default": + self._labels_getter = self._find_labels_default_heuristic + elif callable(labels_getter): + self._labels_getter = labels_getter + elif isinstance(labels_getter, str): + self._labels_getter = lambda inputs: inputs[labels_getter] + elif labels_getter is None: + self._labels_getter = None + else: + raise ValueError( + "labels_getter should either be a str, callable, or 'default'. " + f"Got {labels_getter} of type {type(labels_getter)}." + ) + + @staticmethod + def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive + # Returns None if nothing is found + candidate_key = None + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") + if candidate_key is None: + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) + if candidate_key is None: + raise ValueError( + "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" + "If there are no samples and it is by design, pass labels_getter=None." + ) + return inputs[candidate_key] + + def forward(self, *inputs: Any) -> Any: + inputs = inputs if len(inputs) > 1 else inputs[0] + + if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping): + raise ValueError( + f"If labels_getter is a str or 'default' (got {self.labels_getter}), " + f"then the input to forward() must be a dict. Got {type(inputs)} instead." 
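With the default labels_getter, the heuristic above looks for a "labels"-like key, so forward() expects a dict sample. A usage sketch, independent of the patch; the sample keys and the BoundingBox constructor arguments are illustrative assumptions.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import SanitizeBoundingBoxes

boxes = datapoints.BoundingBox(
    torch.tensor([[0, 0, 20, 20], [5, 5, 5, 5]]),  # second box is degenerate (zero width/height)
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(50, 50),
)
sample = {
    "image": datapoints.Image(torch.rand(3, 50, 50)),
    "boxes": boxes,
    "labels": torch.tensor([1, 2]),
}
out = SanitizeBoundingBoxes()(sample)
print(out["boxes"].shape, out["labels"])  # the degenerate box and its label are dropped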
+ ) + + if self._labels_getter is None: + labels = None + else: + labels = self._labels_getter(inputs) + if labels is not None and not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") + + flat_inputs, spec = tree_flatten(inputs) + # TODO: this enforces one single BoundingBox entry. + # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... + # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? + boxes = query_bounding_box(flat_inputs) + + if boxes.ndim != 2: + raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") + + if labels is not None and boxes.shape[0] != labels.shape[0]: + raise ValueError( + f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." + ) + + boxes = cast( + datapoints.BoundingBox, + F.convert_format_bounding_box( + boxes, + new_format=datapoints.BoundingBoxFormat.XYXY, + ), + ) + ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] + mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? + image_h, image_w = boxes.spatial_size + mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) + mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) + + params = dict(mask=mask, labels=labels) + flat_outputs = [ + # Even-though it may look like we're transforming all inputs, we don't: + # _transform() will only care about BoundingBoxes and the labels + self._transform(inpt, params) + for inpt in flat_inputs + ] + + return tree_unflatten(flat_outputs, spec) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + + if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): + inpt = inpt[params["mask"]] + + return inpt diff --git a/torchvision/prototype/transforms/_temporal.py b/torchvision/transforms/v2/_temporal.py similarity index 53% rename from torchvision/prototype/transforms/_temporal.py rename to torchvision/transforms/v2/_temporal.py index 62fe7f4edf5..ab3b91d6cc2 100644 --- a/torchvision/prototype/transforms/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -1,18 +1,17 @@ from typing import Any, Dict -from torchvision.prototype import datapoints -from torchvision.prototype.transforms import functional as F, Transform +from torchvision import datapoints +from torchvision.transforms.v2 import functional as F, Transform -from torchvision.prototype.transforms.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_simple_tensor class UniformTemporalSubsample(Transform): _transformed_types = (is_simple_tensor, datapoints.Video) - def __init__(self, num_samples: int, temporal_dim: int = -4): + def __init__(self, num_samples: int): super().__init__() self.num_samples = num_samples - self.temporal_dim = temporal_dim def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType: - return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim) + return F.uniform_temporal_subsample(inpt, self.num_samples) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/transforms/v2/_transform.py similarity index 98% rename from torchvision/prototype/transforms/_transform.py rename to 
torchvision/transforms/v2/_transform.py index 7f3c03d5e67..3f92b3c1646 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -7,8 +7,8 @@ import torch from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten -from torchvision.prototype import datapoints -from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor +from torchvision import datapoints +from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor from torchvision.utils import _log_api_usage_once diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py new file mode 100644 index 00000000000..984d5ba50c0 --- /dev/null +++ b/torchvision/transforms/v2/_type_conversion.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import PIL.Image +import torch + +from torchvision import datapoints +from torchvision.transforms.v2 import functional as F, Transform + +from torchvision.transforms.v2.utils import is_simple_tensor + + +class PILToTensor(Transform): + _transformed_types = (PIL.Image.Image,) + + def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: + return F.pil_to_tensor(inpt) + + +class ToImageTensor(Transform): + _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> datapoints.Image: + return F.to_image_tensor(inpt) + + +class ToImagePIL(Transform): + _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) + + def __init__(self, mode: Optional[str] = None) -> None: + super().__init__() + self.mode = mode + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> PIL.Image.Image: + return F.to_image_pil(inpt, mode=self.mode) + + +# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
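# A minimal usage sketch for the conversion transforms in this new file (illustrative
# only; it assumes the `torchvision.transforms.v2` namespace created by this migration
# re-exports `ToImageTensor`/`ToImagePIL`, as the prototype namespace did):
#
#     import torch
#     from torchvision.transforms import v2
#
#     img_tensor = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     pil_img = v2.ToImagePIL()(img_tensor)   # plain uint8 tensor -> PIL.Image.Image
#     img = v2.ToImageTensor()(pil_img)       # PIL image -> datapoints.Image
#     assert isinstance(img, torch.Tensor)
#
# `v2.ToPILImage` refers to the same class through the alias defined just below.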
+ToPILImage = ToImagePIL diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/transforms/v2/_utils.py similarity index 96% rename from torchvision/prototype/transforms/_utils.py rename to torchvision/transforms/v2/_utils.py index f2d818b1326..d68851576d3 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -3,8 +3,8 @@ from collections import defaultdict from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union -from torchvision.prototype import datapoints -from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT +from torchvision import datapoints +from torchvision.datapoints._datapoint import FillType, FillTypeJIT from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py similarity index 100% rename from torchvision/prototype/transforms/functional/__init__.py rename to torchvision/transforms/v2/functional/__init__.py diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py similarity index 97% rename from torchvision/prototype/transforms/functional/_augment.py rename to torchvision/transforms/v2/functional/_augment.py index 0164a0b5b9b..e9d0339a982 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -3,7 +3,7 @@ import PIL.Image import torch -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/transforms/v2/functional/_color.py similarity index 99% rename from torchvision/prototype/transforms/functional/_color.py rename to torchvision/transforms/v2/functional/_color.py index e1c8bb87cdd..2ebb4f044cb 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -3,7 +3,7 @@ import PIL.Image import torch from torch.nn.functional import conv2d -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional_tensor import _max_value diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/transforms/v2/functional/_deprecated.py similarity index 96% rename from torchvision/prototype/transforms/functional/_deprecated.py rename to torchvision/transforms/v2/functional/_deprecated.py index 09870216059..8f035f70889 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/transforms/v2/functional/_deprecated.py @@ -4,7 +4,7 @@ import PIL.Image import torch -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms import functional as _F diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py similarity index 99% rename from torchvision/prototype/transforms/functional/_geometry.py rename to torchvision/transforms/v2/functional/_geometry.py index 22731bb157f..c48250f3b96 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -7,7 +7,7 @@ import torch 
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional import ( _check_antialias, diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py similarity index 99% rename from torchvision/prototype/transforms/functional/_meta.py rename to torchvision/transforms/v2/functional/_meta.py index 5e32516fb8a..c61f7a710d4 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -2,8 +2,8 @@ import PIL.Image import torch -from torchvision.prototype import datapoints -from torchvision.prototype.datapoints import BoundingBoxFormat +from torchvision import datapoints +from torchvision.datapoints import BoundingBoxFormat from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional_tensor import _max_value diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py similarity index 99% rename from torchvision/prototype/transforms/functional/_misc.py rename to torchvision/transforms/v2/functional/_misc.py index 9d0a00f88c3..cf728e27825 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -5,7 +5,7 @@ import torch from torch.nn.functional import conv2d, pad as torch_pad -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once diff --git a/torchvision/prototype/transforms/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py similarity index 55% rename from torchvision/prototype/transforms/functional/_temporal.py rename to torchvision/transforms/v2/functional/_temporal.py index d39a64534ca..438e6b5199a 100644 --- a/torchvision/prototype/transforms/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -1,33 +1,27 @@ import torch -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.utils import _log_api_usage_once from ._utils import is_simple_tensor -def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: +def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 - t_max = video.shape[temporal_dim] - 1 + t_max = video.shape[-4] - 1 indices = torch.linspace(0, t_max, num_samples, device=video.device).long() - return torch.index_select(video, temporal_dim, indices) + return torch.index_select(video, -4, indices) -def uniform_temporal_subsample( - inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4 -) -> datapoints.VideoTypeJIT: +def uniform_temporal_subsample(inpt: datapoints.VideoTypeJIT, num_samples: int) -> datapoints.VideoTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(uniform_temporal_subsample) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) + return uniform_temporal_subsample_video(inpt, num_samples) elif isinstance(inpt, datapoints.Video): - 
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim: - raise ValueError("Video inputs must have temporal_dim equivalent to -4") - output = uniform_temporal_subsample_video( - inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim - ) + output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples) return datapoints.Video.wrap_like(inpt, output) else: raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py similarity index 95% rename from torchvision/prototype/transforms/functional/_type_conversion.py rename to torchvision/transforms/v2/functional/_type_conversion.py index 286aa7485da..67572cf4a72 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -3,7 +3,7 @@ import numpy as np import PIL.Image import torch -from torchvision.prototype import datapoints +from torchvision import datapoints from torchvision.transforms import functional as _F diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py similarity index 70% rename from torchvision/prototype/transforms/functional/_utils.py rename to torchvision/transforms/v2/functional/_utils.py index e4efeb6016f..f31ccb939a5 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,7 +1,7 @@ from typing import Any import torch -from torchvision.prototype.datapoints._datapoint import Datapoint +from torchvision.datapoints._datapoint import Datapoint def is_simple_tensor(inpt: Any) -> bool: diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/transforms/v2/utils.py similarity index 94% rename from torchvision/prototype/transforms/utils.py rename to torchvision/transforms/v2/utils.py index ff7fff50ced..c4cf481bcd2 100644 --- a/torchvision/prototype/transforms/utils.py +++ b/torchvision/transforms/v2/utils.py @@ -3,10 +3,10 @@ from typing import Any, Callable, List, Tuple, Type, Union import PIL.Image +from torchvision import datapoints from torchvision._utils import sequence_to_str -from torchvision.prototype import datapoints -from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
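# A short end-to-end sketch tying `query_bounding_box` to the `SanitizeBoundingBoxes`
# transform earlier in this patch (illustrative only; it assumes `torchvision.datapoints`
# and `torchvision.transforms.v2` expose these names after the migration):
#
#     import torch
#     from torchvision import datapoints
#     from torchvision.transforms import v2
#
#     boxes = datapoints.BoundingBox(
#         [[0, 0, 10, 10], [5, 5, 5, 5]],   # second box is degenerate (zero width/height)
#         format=datapoints.BoundingBoxFormat.XYXY,
#         spatial_size=(32, 32),
#     )
#     sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}
#     out = v2.SanitizeBoundingBoxes()(sample)   # "labels" key is picked up by the default heuristic
#     # expected: only the first box and its matching label survive,
#     # i.e. out["boxes"].shape[0] == 1 and out["labels"].tolist() == [1]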