Skip to content

Add JPEG augmentation #8316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ Miscellaneous
v2.SanitizeBoundingBoxes
v2.ClampBoundingBoxes
v2.UniformTemporalSubsample
v2.JPEG

Functionals

Expand All @@ -419,6 +420,7 @@ Functionals
v2.functional.sanitize_bounding_boxes
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample
v2.functional.jpeg

.. _conversion_transforms:

Expand Down
11 changes: 11 additions & 0 deletions gallery/transforms/plot_transforms_illustrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot([orig_img] + equalized_imgs)

# %%
# JPEG
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.v2.JPEG` transform
# (see also :func:`~torchvision.transforms.v2.functional.jpeg`)
# applies JPEG compression to the given image with a random
# degree of compression.
jpeg = v2.JPEG((5, 50))
jpeg_imgs = [jpeg(orig_img) for _ in range(4)]
plot([orig_img] + jpeg_imgs)

# %%
# Augmentation Transforms
# -----------------------
Expand Down
83 changes: 83 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5932,3 +5932,86 @@ def test_errors_functional(self):

with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
F.sanitize_bounding_boxes(good_bbox.tolist())


class TestJPEG:
    """Tests for the v2 ``JPEG`` transform, its functional, and its kernels."""

    @pytest.mark.parametrize("quality", [5, 75])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_kernel_image(self, quality, color_space):
        # Smoke-test the image kernel across qualities and channel layouts.
        check_kernel(F.jpeg_image, make_image(color_space=color_space), quality=quality)

    def test_kernel_video(self):
        check_kernel(F.jpeg_video, make_video(), quality=5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        # The functional must accept plain tensors, PIL images, and tv_tensors.
        check_functional(F.jpeg, make_input(), quality=5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.jpeg_image, torch.Tensor),
            (F._jpeg_image_pil, PIL.Image.Image),
            (F.jpeg_image, tv_tensors.Image),
            (F.jpeg_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        # Each registered kernel's signature must match the dispatcher's.
        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_transform(self, make_input, quality, color_space):
        check_transform(transforms.JPEG(quality=quality), make_input(color_space=color_space))

    @pytest.mark.parametrize("quality", [5])
    def test_functional_image_correctness(self, quality):
        # Compare the tensor path against a PIL round-trip of the same image.
        image = make_image()

        actual = F.jpeg(image, quality=quality)
        expected = F.to_image(F.jpeg(F.to_pil_image(image), quality=quality))

        # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, quality, color_space, seed):
        image = make_image(color_space=color_space)

        transform = transforms.JPEG(quality=quality)

        # Re-seed before each call so the tensor and PIL paths draw the same
        # random quality from the transform's range.
        with freeze_rng_state():
            torch.manual_seed(seed)
            actual = transform(image)

            torch.manual_seed(seed)
            expected = F.to_image(transform(F.to_pil_image(image)))

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, quality, seed):
        # The sampled quality must stay within the configured (inclusive) range.
        transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            params = transform._get_params([])

        if isinstance(quality, int):
            assert params["quality"] == quality
        else:
            assert quality[0] <= params["quality"] <= quality[1]

    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
    def test_transform_sequence_len_error(self, quality):
        # Sequences must have exactly two entries: (min, max).
        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
            transforms.JPEG(quality=quality)

    @pytest.mark.parametrize("quality", [-1, 0, 150])
    def test_transform_invalid_quality_error(self, quality):
        # Quality outside [1, 100] must be rejected at construction time.
        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
            transforms.JPEG(quality=quality)
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import CutMix, MixUp, RandomErasing
from ._augment import CutMix, JPEG, MixUp, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Expand Down
40 changes: 38 additions & 2 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

import PIL.Image
import torch
Expand All @@ -11,7 +11,7 @@
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -317,3 +317,39 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output
else:
return inpt


class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    :class:`torch.Tensor` inputs are expected to be uint8, on CPU, with
    ``[..., 3 or 1, H, W]`` shape, where ``...`` denotes any number of
    leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            A sequence like ``(min, max)`` gives the inclusive range from which
            the quality is randomly drawn on every call.

    Returns:
        image with JPEG compression.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        # Normalize: a single int means a fixed quality, stored as [q, q].
        if isinstance(quality, int):
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        lo, hi = quality[0], quality[1]
        # Bounds check first (same evaluation order as the chained comparison),
        # then make sure both endpoints really are ints.
        if not (1 <= lo <= hi <= 100 and isinstance(lo, int) and isinstance(hi, int)):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Draw one quality per call; randint's upper bound is exclusive, hence +1.
        lo, hi = self.quality
        return dict(quality=torch.randint(lo, hi + 1, ()).item())

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Dispatch to the functional kernel registered for the input's type.
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_size,
) # usort: skip

from ._augment import _erase_image_pil, erase, erase_image, erase_video
from ._augment import _erase_image_pil, _jpeg_image_pil, erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video
from ._color import (
_adjust_brightness_image_pil,
_adjust_contrast_image_pil,
Expand Down
43 changes: 43 additions & 0 deletions torchvision/transforms/v2/functional/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import io

import PIL.Image

import torch
from torchvision import tv_tensors
from torchvision.io import decode_jpeg, encode_jpeg
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

Expand Down Expand Up @@ -53,3 +56,43 @@ def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details."""
    # TorchScript cannot do dynamic dispatch; go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    # Eager mode: resolve the kernel registered for this input type, then call it.
    return _get_kernel(jpeg, type(image))(image, quality=quality)


@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    """Round-trip ``image`` through JPEG encode/decode at the given quality.

    Args:
        image: uint8 CPU tensor of shape ``[..., 1 or 3, H, W]`` — assumed per
            the JPEG transform's documented contract; not re-validated here.
        quality: JPEG quality in [1, 100]; lower means stronger compression.

    Returns:
        Tensor of the same shape as ``image`` carrying JPEG compression artifacts.
    """
    original_shape = image.shape
    # Flatten all leading dims into one batch dim. reshape() (rather than view())
    # also accepts non-contiguous inputs, e.g. slices of a larger tensor, which
    # view() would reject with a RuntimeError.
    image = image.reshape((-1,) + image.shape[-3:])

    if image.shape[0] == 0:  # degenerate: empty batch, nothing to encode
        return image.reshape(original_shape).clone()

    # encode_jpeg/decode_jpeg operate on single images, so process one at a time.
    compressed = [decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0])]
    return torch.stack(compressed, dim=0).view(original_shape)


@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    """Apply JPEG compression to every frame of ``video``; see ``jpeg_image``."""
    # A video tensor is just a batch of images here, so the image kernel
    # handles the leading frame dimension transparently.
    return jpeg_image(video, quality=quality)


@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    """PIL kernel: encode ``image`` as an in-memory JPEG and decode it again."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # PIL.Image.open() returns a PIL.JpegImagePlugin.JpegImageFile, a subclass
    # of PIL.Image.Image; copy() materializes a plain Image so the
    # check_transform() test's type comparison passes.
    decoded = PIL.Image.open(buffer)
    return decoded.copy()