Skip to content

Refactors test_image.py so tests don't write files to assets folder #3018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 30, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 41 additions & 47 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import os
import io
import glob
import io
import os
import unittest

import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode)
import numpy as np

from common_utils import get_tmp_dir


IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
Expand All @@ -22,14 +21,10 @@

def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
continue

for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
for path in image_paths:
if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
yield path


def pil_read_image(img_path):
Expand Down Expand Up @@ -75,7 +70,7 @@ def test_decode_jpeg(self):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
decode_jpeg(torch.empty((100,), dtype=torch.float16))

with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8))
Expand Down Expand Up @@ -119,12 +114,12 @@ def test_encode_jpeg(self):

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with self.assertRaisesRegex(
Expand All @@ -140,27 +135,27 @@ def test_encode_jpeg(self):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

def test_write_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)
write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)
self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
Expand Down Expand Up @@ -216,20 +211,19 @@ def test_encode_png(self):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))

def test_write_png(self):
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.remove(torch_png)
saved_image = saved_image.permute(2, 0, 1)

self.assertTrue(img_pil.equal(saved_image))
with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)

filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1)

self.assertTrue(img_pil.equal(saved_image))

def test_read_file(self):
with get_tmp_dir() as d:
Expand Down