Skip to content

Commit cdf0f45

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] Add JPEG augmentation (#8316)
Reviewed By: vmoens
Differential Revision: D55062770
fbshipit-source-id: 926a1eea4f55cb0b3c1a4f379088c1505ec70479
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 6ce2ceb commit cdf0f45

File tree

7 files changed

+179
-4
lines changed

7 files changed

+179
-4
lines changed

docs/source/transforms.rst

+2
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ Miscellaneous
407407
v2.SanitizeBoundingBoxes
408408
v2.ClampBoundingBoxes
409409
v2.UniformTemporalSubsample
410+
v2.JPEG
410411

411412
Functionals
412413

@@ -419,6 +420,7 @@ Functionals
419420
v2.functional.sanitize_bounding_boxes
420421
v2.functional.clamp_bounding_boxes
421422
v2.functional.uniform_temporal_subsample
423+
v2.functional.jpeg
422424

423425
.. _conversion_transforms:
424426

gallery/transforms/plot_transforms_illustrations.py

+11
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,17 @@
237237
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
238238
plot([orig_img] + equalized_imgs)
239239

240+
# %%
241+
# JPEG
242+
# ~~~~~~~~~~~~~~
243+
# The :class:`~torchvision.transforms.v2.JPEG` transform
244+
# (see also :func:`~torchvision.transforms.v2.functional.jpeg`)
245+
# applies JPEG compression to the given image with random
246+
# degree of compression.
247+
jpeg = v2.JPEG((5, 50))
248+
jpeg_imgs = [jpeg(orig_img) for _ in range(4)]
249+
plot([orig_img] + jpeg_imgs)
250+
240251
# %%
241252
# Augmentation Transforms
242253
# -----------------------

test/test_transforms_v2.py

+83
Original file line numberDiff line numberDiff line change
@@ -5932,3 +5932,86 @@ def test_errors_functional(self):
59325932

59335933
with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
59345934
F.sanitize_bounding_boxes(good_bbox.tolist())
5935+
5936+
5937+
class TestJPEG:
    """Tests for the v2 ``JPEG`` transform, the ``F.jpeg`` functional, and its kernels."""

    @pytest.mark.parametrize("quality", [5, 75])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_kernel_image(self, quality, color_space):
        check_kernel(F.jpeg_image, make_image(color_space=color_space), quality=quality)

    def test_kernel_video(self):
        check_kernel(F.jpeg_video, make_video(), quality=5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        check_functional(F.jpeg, make_input(), quality=5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.jpeg_image, torch.Tensor),
            (F._jpeg_image_pil, PIL.Image.Image),
            (F.jpeg_image, tv_tensors.Image),
            (F.jpeg_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_transform(self, make_input, quality, color_space):
        check_transform(transforms.JPEG(quality=quality), make_input(color_space=color_space))

    @pytest.mark.parametrize("quality", [5])
    def test_functional_image_correctness(self, quality):
        img = make_image()

        actual = F.jpeg(img, quality=quality)
        expected = F.to_image(F.jpeg(F.to_pil_image(img), quality=quality))

        # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, quality, color_space, seed):
        img = make_image(color_space=color_space)
        jpeg_transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            # Re-seed before each call so both paths sample the same quality.
            torch.manual_seed(seed)
            actual = jpeg_transform(img)

            torch.manual_seed(seed)
            expected = F.to_image(jpeg_transform(F.to_pil_image(img)))

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, quality, seed):
        jpeg_transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            params = jpeg_transform._get_params([])

        if isinstance(quality, int):
            # A scalar quality must be passed through unchanged.
            assert params["quality"] == quality
        else:
            # A (min, max) range must be sampled inclusively at both ends.
            low, high = quality
            assert low <= params["quality"] <= high

    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
    def test_transform_sequence_len_error(self, quality):
        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
            transforms.JPEG(quality=quality)

    @pytest.mark.parametrize("quality", [-1, 0, 150])
    def test_transform_invalid_quality_error(self, quality):
        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
            transforms.JPEG(quality=quality)

torchvision/transforms/v2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._transform import Transform # usort: skip
66

7-
from ._augment import CutMix, MixUp, RandomErasing
7+
from ._augment import CutMix, JPEG, MixUp, RandomErasing
88
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
99
from ._color import (
1010
ColorJitter,

torchvision/transforms/v2/_augment.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import numbers
33
import warnings
4-
from typing import Any, Callable, Dict, List, Tuple
4+
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
55

66
import PIL.Image
77
import torch
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2 import functional as F
1212

1313
from ._transform import _RandomApplyTransform, Transform
14-
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
14+
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
1515

1616

1717
class RandomErasing(_RandomApplyTransform):
@@ -317,3 +317,39 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
317317
return output
318318
else:
319319
return inpt
320+
321+
322+
class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        # A scalar quality is treated as the degenerate range [quality, quality].
        if isinstance(quality, int):
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        # Bounds are checked with a single chained comparison; the int check
        # comes after so that non-comparable entries fail the same way.
        if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One quality per call, shared by every input in the sample.
        # randint's upper bound is exclusive, hence the +1 for an inclusive range.
        low, high = self.quality
        return dict(quality=torch.randint(low, high + 1, ()).item())

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])

torchvision/transforms/v2/functional/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
get_size,
2525
) # usort: skip
2626

27-
from ._augment import _erase_image_pil, erase, erase_image, erase_video
27+
from ._augment import _erase_image_pil, _jpeg_image_pil, erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video
2828
from ._color import (
2929
_adjust_brightness_image_pil,
3030
_adjust_contrast_image_pil,

torchvision/transforms/v2/functional/_augment.py

+43
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import io
2+
13
import PIL.Image
24

35
import torch
46
from torchvision import tv_tensors
7+
from torchvision.io import decode_jpeg, encode_jpeg
58
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
69
from torchvision.utils import _log_api_usage_once
710

@@ -53,3 +56,43 @@ def erase_video(
5356
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
5457
) -> torch.Tensor:
5558
return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
59+
60+
61+
def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details."""
    # TorchScript cannot do dynamic dispatch, so go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    # Dispatch to whichever kernel is registered for this input type
    # (pure tensor, tv_tensors.Image/Video, or PIL image).
    dispatch = _get_kernel(jpeg, type(image))
    return dispatch(image, quality=quality)
70+
71+
72+
@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    # Flatten all leading dims into one batch dim so we can encode frame by frame.
    original_shape = image.shape
    image = image.view((-1,) + image.shape[-3:])

    if image.shape[0] == 0:  # degenerate: nothing to encode
        return image.reshape(original_shape).clone()

    # Round-trip each frame through the JPEG codec; this is what actually
    # introduces the compression artifacts.
    frames = [decode_jpeg(encode_jpeg(frame, quality=quality)) for frame in image]
    return torch.stack(frames, dim=0).view(original_shape)
84+
85+
86+
@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    # A video is just extra leading dimensions on an image; the image kernel
    # already flattens and handles arbitrary leading dims.
    return jpeg_image(video, quality=quality)
89+
90+
91+
@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    # Round-trip through an in-memory JPEG to introduce compression artifacts.
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # Copy so we hand back a plain PIL.Image.Image: PIL.Image.open() returns a
    # PIL.JpegImagePlugin.JpegImageFile subclass, which fails check_transform().
    return PIL.Image.open(buffer).copy()

0 commit comments

Comments
 (0)