|
1 | 1 | import glob |
2 | 2 | import io |
3 | 3 | import os |
| 4 | +import sys |
4 | 5 | import unittest |
| 6 | +from pathlib import Path |
5 | 7 |
|
6 | 8 | import pytest |
7 | 9 | import numpy as np |
8 | 10 | import torch |
9 | 11 | from PIL import Image |
10 | | -from common_utils import get_tmp_dir, needs_cuda |
| 12 | +import torchvision.transforms.functional as F |
| 13 | +from common_utils import get_tmp_dir, needs_cuda, cpu_only |
11 | 14 | from _assert_utils import assert_equal |
12 | 15 |
|
13 | 16 | from torchvision.io.image import ( |
14 | 17 | decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, |
15 | | - encode_png, write_png, write_file, ImageReadMode) |
| 18 | + encode_png, write_png, write_file, ImageReadMode, read_image) |
16 | 19 |
|
17 | 20 | IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") |
18 | 21 | FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") |
19 | 22 | IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") |
20 | 23 | DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') |
21 | 24 | ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") |
| 25 | +IS_WINDOWS = sys.platform in ('win32', 'cygwin') |
| 26 | + |
| 27 | + |
| 28 | +def _get_safe_image_name(name): |
| 29 | + # Used when we need to change the pytest "id" for an "image path" parameter. |
| 30 | + # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific, |
| 31 | + # and this creates issues when the test is running in a different machine than where it was collected |
| 32 | + # (typically, in fb internal infra) |
| 33 | + return name.split(os.path.sep)[-1] |
22 | 34 |
|
23 | 35 |
|
24 | 36 | def get_images(directory, img_ext): |
@@ -93,72 +105,6 @@ def test_damaged_images(self): |
93 | 105 | with self.assertRaises(RuntimeError): |
94 | 106 | decode_jpeg(data) |
95 | 107 |
|
96 | | - def test_encode_jpeg(self): |
97 | | - for img_path in get_images(ENCODE_JPEG, ".jpg"): |
98 | | - dirname = os.path.dirname(img_path) |
99 | | - filename, _ = os.path.splitext(os.path.basename(img_path)) |
100 | | - write_folder = os.path.join(dirname, 'jpeg_write') |
101 | | - expected_file = os.path.join( |
102 | | - write_folder, '{0}_pil.jpg'.format(filename)) |
103 | | - img = decode_jpeg(read_file(img_path)) |
104 | | - |
105 | | - with open(expected_file, 'rb') as f: |
106 | | - pil_bytes = f.read() |
107 | | - pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) |
108 | | - for src_img in [img, img.contiguous()]: |
109 | | - # PIL sets jpeg quality to 75 by default |
110 | | - jpeg_bytes = encode_jpeg(src_img, quality=75) |
111 | | - assert_equal(jpeg_bytes, pil_bytes) |
112 | | - |
113 | | - with self.assertRaisesRegex( |
114 | | - RuntimeError, "Input tensor dtype should be uint8"): |
115 | | - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) |
116 | | - |
117 | | - with self.assertRaisesRegex( |
118 | | - ValueError, "Image quality should be a positive number " |
119 | | - "between 1 and 100"): |
120 | | - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) |
121 | | - |
122 | | - with self.assertRaisesRegex( |
123 | | - ValueError, "Image quality should be a positive number " |
124 | | - "between 1 and 100"): |
125 | | - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) |
126 | | - |
127 | | - with self.assertRaisesRegex( |
128 | | - RuntimeError, "The number of channels should be 1 or 3, got: 5"): |
129 | | - encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) |
130 | | - |
131 | | - with self.assertRaisesRegex( |
132 | | - RuntimeError, "Input data should be a 3-dimensional tensor"): |
133 | | - encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) |
134 | | - |
135 | | - with self.assertRaisesRegex( |
136 | | - RuntimeError, "Input data should be a 3-dimensional tensor"): |
137 | | - encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) |
138 | | - |
139 | | - def test_write_jpeg(self): |
140 | | - with get_tmp_dir() as d: |
141 | | - for img_path in get_images(ENCODE_JPEG, ".jpg"): |
142 | | - data = read_file(img_path) |
143 | | - img = decode_jpeg(data) |
144 | | - |
145 | | - basedir = os.path.dirname(img_path) |
146 | | - filename, _ = os.path.splitext(os.path.basename(img_path)) |
147 | | - torch_jpeg = os.path.join( |
148 | | - d, '{0}_torch.jpg'.format(filename)) |
149 | | - pil_jpeg = os.path.join( |
150 | | - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) |
151 | | - |
152 | | - write_jpeg(img, torch_jpeg, quality=75) |
153 | | - |
154 | | - with open(torch_jpeg, 'rb') as f: |
155 | | - torch_bytes = f.read() |
156 | | - |
157 | | - with open(pil_jpeg, 'rb') as f: |
158 | | - pil_bytes = f.read() |
159 | | - |
160 | | - self.assertEqual(torch_bytes, pil_bytes) |
161 | | - |
162 | 108 | def test_decode_png(self): |
163 | 109 | conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), |
164 | 110 | ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] |
@@ -282,11 +228,7 @@ def test_write_file_non_ascii(self): |
282 | 228 |
|
283 | 229 | @needs_cuda |
284 | 230 | @pytest.mark.parametrize('img_path', [ |
285 | | - # We need to change the "id" for that parameter. |
286 | | - # If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific, |
287 | | - # and this creates issues when the test is running in a different machine than where it was collected |
288 | | - # (typically, in fb internal infra) |
289 | | - pytest.param(jpeg_path, id=jpeg_path.split('/')[-1]) |
| 231 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
290 | 232 | for jpeg_path in get_images(IMAGE_ROOT, ".jpg") |
291 | 233 | ]) |
292 | 234 | @pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) |
@@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors(): |
325 | 267 | torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') |
326 | 268 |
|
327 | 269 |
|
| 270 | +@cpu_only |
| 271 | +def test_encode_jpeg_errors(): |
| 272 | + |
| 273 | + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): |
| 274 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) |
| 275 | + |
| 276 | + with pytest.raises(ValueError, match="Image quality should be a positive number " |
| 277 | + "between 1 and 100"): |
| 278 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) |
| 279 | + |
| 280 | + with pytest.raises(ValueError, match="Image quality should be a positive number " |
| 281 | + "between 1 and 100"): |
| 282 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) |
| 283 | + |
| 284 | + with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): |
| 285 | + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) |
| 286 | + |
| 287 | + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): |
| 288 | + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) |
| 289 | + |
| 290 | + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): |
| 291 | + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) |
| 292 | + |
| 293 | + |
| 294 | +def _collect_if(cond): |
| 295 | + # TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows |
| 296 | + # are removed |
| 297 | + def _inner(test_func): |
| 298 | + if cond: |
| 299 | + return test_func |
| 300 | + else: |
| 301 | + return pytest.mark.dont_collect(test_func) |
| 302 | + return _inner |
| 303 | + |
| 304 | + |
| 305 | +@cpu_only |
| 306 | +@_collect_if(cond=IS_WINDOWS) |
| 307 | +def test_encode_jpeg_windows(): |
| 308 | + # This test is *wrong*. |
| 309 | + # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it |
| 310 | + # starts encoding the torchvision version from an image that comes from |
| 311 | + # decode_jpeg, which can yield different results from pil.decode (see |
| 312 | + # test_decode... which uses a high tolerance). |
| 313 | + # Instead, we should start encoding from the exact same decoded image, for a |
| 314 | + # valid comparison. This is done in test_encode_jpeg, but unfortunately |
| 315 | + # these more correct tests fail on windows (probably because of a difference |
| 316 | + # in libjpeg) between torchvision and PIL. |
| 317 | + # FIXME: make the correct tests pass on windows and remove this. |
| 318 | + for img_path in get_images(ENCODE_JPEG, ".jpg"): |
| 319 | + dirname = os.path.dirname(img_path) |
| 320 | + filename, _ = os.path.splitext(os.path.basename(img_path)) |
| 321 | + write_folder = os.path.join(dirname, 'jpeg_write') |
| 322 | + expected_file = os.path.join( |
| 323 | + write_folder, '{0}_pil.jpg'.format(filename)) |
| 324 | + img = decode_jpeg(read_file(img_path)) |
| 325 | + |
| 326 | + with open(expected_file, 'rb') as f: |
| 327 | + pil_bytes = f.read() |
| 328 | + pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) |
| 329 | + for src_img in [img, img.contiguous()]: |
| 330 | + # PIL sets jpeg quality to 75 by default |
| 331 | + jpeg_bytes = encode_jpeg(src_img, quality=75) |
| 332 | + assert_equal(jpeg_bytes, pil_bytes) |
| 333 | + |
| 334 | + |
| 335 | +@cpu_only |
| 336 | +@_collect_if(cond=IS_WINDOWS) |
| 337 | +def test_write_jpeg_windows(): |
| 338 | + # FIXME: Remove this eventually, see test_encode_jpeg_windows |
| 339 | + with get_tmp_dir() as d: |
| 340 | + for img_path in get_images(ENCODE_JPEG, ".jpg"): |
| 341 | + data = read_file(img_path) |
| 342 | + img = decode_jpeg(data) |
| 343 | + |
| 344 | + basedir = os.path.dirname(img_path) |
| 345 | + filename, _ = os.path.splitext(os.path.basename(img_path)) |
| 346 | + torch_jpeg = os.path.join( |
| 347 | + d, '{0}_torch.jpg'.format(filename)) |
| 348 | + pil_jpeg = os.path.join( |
| 349 | + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) |
| 350 | + |
| 351 | + write_jpeg(img, torch_jpeg, quality=75) |
| 352 | + |
| 353 | + with open(torch_jpeg, 'rb') as f: |
| 354 | + torch_bytes = f.read() |
| 355 | + |
| 356 | + with open(pil_jpeg, 'rb') as f: |
| 357 | + pil_bytes = f.read() |
| 358 | + |
| 359 | + assert_equal(torch_bytes, pil_bytes) |
| 360 | + |
| 361 | + |
| 362 | +@cpu_only |
| 363 | +@_collect_if(cond=not IS_WINDOWS) |
| 364 | +@pytest.mark.parametrize('img_path', [ |
| 365 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
| 366 | + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") |
| 367 | +]) |
| 368 | +def test_encode_jpeg(img_path): |
| 369 | + img = read_image(img_path) |
| 370 | + |
| 371 | + pil_img = F.to_pil_image(img) |
| 372 | + buf = io.BytesIO() |
| 373 | + pil_img.save(buf, format='JPEG', quality=75) |
| 374 | + |
| 375 | + # pytorch can't read from raw bytes so we go through numpy |
| 376 | + pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) |
| 377 | + encoded_jpeg_pil = torch.as_tensor(pil_bytes) |
| 378 | + |
| 379 | + for src_img in [img, img.contiguous()]: |
| 380 | + encoded_jpeg_torch = encode_jpeg(src_img, quality=75) |
| 381 | + assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) |
| 382 | + |
| 383 | + |
| 384 | +@cpu_only |
| 385 | +@_collect_if(cond=not IS_WINDOWS) |
| 386 | +@pytest.mark.parametrize('img_path', [ |
| 387 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
| 388 | + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") |
| 389 | +]) |
| 390 | +def test_write_jpeg(img_path): |
| 391 | + with get_tmp_dir() as d: |
| 392 | + d = Path(d) |
| 393 | + img = read_image(img_path) |
| 394 | + pil_img = F.to_pil_image(img) |
| 395 | + |
| 396 | + torch_jpeg = str(d / 'torch.jpg') |
| 397 | + pil_jpeg = str(d / 'pil.jpg') |
| 398 | + |
| 399 | + write_jpeg(img, torch_jpeg, quality=75) |
| 400 | + pil_img.save(pil_jpeg, quality=75) |
| 401 | + |
| 402 | + with open(torch_jpeg, 'rb') as f: |
| 403 | + torch_bytes = f.read() |
| 404 | + |
| 405 | + with open(pil_jpeg, 'rb') as f: |
| 406 | + pil_bytes = f.read() |
| 407 | + |
| 408 | + assert_equal(torch_bytes, pil_bytes) |
| 409 | + |
| 410 | + |
328 | 411 | if __name__ == '__main__': |
329 | 412 | unittest.main() |
0 commit comments