diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index d4c94e4760e..54ed18394cd 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -407,6 +407,7 @@ Miscellaneous
     v2.SanitizeBoundingBoxes
     v2.ClampBoundingBoxes
     v2.UniformTemporalSubsample
+    v2.JPEG
 
 Functionals
 
@@ -419,6 +420,7 @@ Functionals
     v2.functional.sanitize_bounding_boxes
     v2.functional.clamp_bounding_boxes
     v2.functional.uniform_temporal_subsample
+    v2.functional.jpeg
 
 .. _conversion_transforms:
 
diff --git a/gallery/transforms/plot_transforms_illustrations.py b/gallery/transforms/plot_transforms_illustrations.py
index 95ab455d0fd..2145a74d5e2 100644
--- a/gallery/transforms/plot_transforms_illustrations.py
+++ b/gallery/transforms/plot_transforms_illustrations.py
@@ -237,6 +237,17 @@
 equalized_imgs = [equalizer(orig_img) for _ in range(4)]
 plot([orig_img] + equalized_imgs)
 
+# %%
+# JPEG
+# ~~~~~~~~~~~~~~
+# The :class:`~torchvision.transforms.v2.JPEG` transform
+# (see also :func:`~torchvision.transforms.v2.functional.jpeg`)
+# applies JPEG compression to the given image with a random
+# degree of compression.
+jpeg = v2.JPEG((5, 50))
+jpeg_imgs = [jpeg(orig_img) for _ in range(4)]
+plot([orig_img] + jpeg_imgs)
+
 # %%
 # Augmentation Transforms
 # -----------------------
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 1ad47dda02e..32664bdd959 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5932,3 +5932,86 @@ def test_errors_functional(self):
 
         with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
             F.sanitize_bounding_boxes(good_bbox.tolist())
+
+
+class TestJPEG:
+    @pytest.mark.parametrize("quality", [5, 75])
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    def test_kernel_image(self, quality, color_space):
+        check_kernel(F.jpeg_image, make_image(color_space=color_space), quality=quality)
+
+    def test_kernel_video(self):
+        check_kernel(F.jpeg_video, make_video(), quality=5)
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
+    def test_functional(self, make_input):
+        check_functional(F.jpeg, make_input(), quality=5)
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.jpeg_image, torch.Tensor),
+            (F._jpeg_image_pil, PIL.Image.Image),
+            (F.jpeg_image, tv_tensors.Image),
+            (F.jpeg_video, tv_tensors.Video),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
+    @pytest.mark.parametrize("quality", [5, (10, 20)])
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    def test_transform(self, make_input, quality, color_space):
+        check_transform(transforms.JPEG(quality=quality), make_input(color_space=color_space))
+
+    @pytest.mark.parametrize("quality", [5])
+    def test_functional_image_correctness(self, quality):
+        image = make_image()
+
+        actual = F.jpeg(image, quality=quality)
+        expected = F.to_image(F.jpeg(F.to_pil_image(image), quality=quality))
+
+        # NOTE: this will fail if torchvision and Pillow use different JPEG encoders/decoders
+        torch.testing.assert_close(actual, expected, rtol=0, atol=1)
+
+    @pytest.mark.parametrize("quality", [5, (10, 20)])
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    @pytest.mark.parametrize("seed", list(range(5)))
+    def test_transform_image_correctness(self, quality, color_space, seed):
+        image = make_image(color_space=color_space)
+
+        transform = transforms.JPEG(quality=quality)
+
+        with freeze_rng_state():
+            torch.manual_seed(seed)
+            actual = transform(image)
+
+            torch.manual_seed(seed)
+            expected = F.to_image(transform(F.to_pil_image(image)))
+
+        torch.testing.assert_close(actual, expected, rtol=0, atol=1)
+
+    @pytest.mark.parametrize("quality", [5, (10, 20)])
+    @pytest.mark.parametrize("seed", list(range(10)))
+    def test_transform_get_params_bounds(self, quality, seed):
+        transform = transforms.JPEG(quality=quality)
+
+        with freeze_rng_state():
+            torch.manual_seed(seed)
+            params = transform._get_params([])
+
+        if isinstance(quality, int):
+            assert params["quality"] == quality
+        else:
+            assert quality[0] <= params["quality"] <= quality[1]
+
+    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
+    def test_transform_sequence_len_error(self, quality):
+        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
+            transforms.JPEG(quality=quality)
+
+    @pytest.mark.parametrize("quality", [-1, 0, 150])
+    def test_transform_invalid_quality_error(self, quality):
+        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
+            transforms.JPEG(quality=quality)
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index fea39d3cf20..6dccb8a5b78 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -4,7 +4,7 @@
 
 from ._transform import Transform  # usort: skip
 
-from ._augment import CutMix, MixUp, RandomErasing
+from ._augment import CutMix, JPEG, MixUp, RandomErasing
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index caddcac811c..cc645d6c8a8 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -1,7 +1,7 @@
 import math
 import numbers
 import warnings
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
 
 import PIL.Image
 import torch
@@ -11,7 +11,7 @@
 from torchvision.transforms.v2 import functional as F
 
 from ._transform import _RandomApplyTransform, Transform
-from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
+from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
 
 
 class RandomErasing(_RandomApplyTransform):
@@ -317,3 +317,39 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             return output
         else:
             return inpt
+
+
+class JPEG(Transform):
+    """Apply JPEG compression and decompression to the given images.
+
+    If the input is a :class:`torch.Tensor`, it is expected
+    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
+    where ... means an arbitrary number of leading dimensions.
+
+    Args:
+        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
+            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
+            randomly select from (inclusive of both ends).
+
+    Returns:
+        image with JPEG compression.
+    """
+
+    def __init__(self, quality: Union[int, Sequence[int]]):
+        super().__init__()
+        if isinstance(quality, int):
+            quality = [quality, quality]
+        else:
+            _check_sequence_input(quality, "quality", req_sizes=(2,))
+
+        if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)):
+            raise ValueError(f"quality must be an integer from 1 to 100, got {quality=}")
+
+        self.quality = quality
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
+        return dict(quality=quality)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 69f5f4521fa..fbc64200984 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -24,7 +24,7 @@
     get_size,
 )  # usort: skip
 
-from ._augment import _erase_image_pil, erase, erase_image, erase_video
+from ._augment import _erase_image_pil, _jpeg_image_pil, erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video
 from ._color import (
     _adjust_brightness_image_pil,
     _adjust_contrast_image_pil,
diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py
index 78d4c354160..eac27f37022 100644
--- a/torchvision/transforms/v2/functional/_augment.py
+++ b/torchvision/transforms/v2/functional/_augment.py
@@ -1,7 +1,10 @@
+import io
+
 import PIL.Image
 import torch
 
 from torchvision import tv_tensors
+from torchvision.io import decode_jpeg, encode_jpeg
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
@@ -53,3 +56,43 @@ def erase_video(
     video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
     return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+
+
+def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
+    """See :class:`~torchvision.transforms.v2.JPEG` for details."""
+    if torch.jit.is_scripting():
+        return jpeg_image(image, quality=quality)
+
+    _log_api_usage_once(jpeg)
+
+    kernel = _get_kernel(jpeg, type(image))
+    return kernel(image, quality=quality)
+
+
+@_register_kernel_internal(jpeg, torch.Tensor)
+@_register_kernel_internal(jpeg, tv_tensors.Image)
+def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
+    original_shape = image.shape
+    image = image.view((-1,) + image.shape[-3:])
+
+    if image.shape[0] == 0:  # degenerate
+        return image.reshape(original_shape).clone()
+
+    image = [decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0])]
+    image = torch.stack(image, dim=0).view(original_shape)
+    return image
+
+
+@_register_kernel_internal(jpeg, tv_tensors.Video)
+def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
+    return jpeg_image(video, quality=quality)
+
+
+@_register_kernel_internal(jpeg, PIL.Image.Image)
+def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
+    raw_jpeg = io.BytesIO()
+    image.save(raw_jpeg, format="JPEG", quality=quality)
+
+    # we need to copy since PIL.Image.open() will return PIL.JpegImagePlugin.JpegImageFile,
+    # which is a sub-class of PIL.Image.Image. This would fail the check_transform() test.
+    return PIL.Image.open(raw_jpeg).copy()