Skip to content

Commit 8e24479

Browse files
authored
Refactors test_image.py so tests don't write files to assets folder (#3018)
* Fix writing to files by using get_tmp_dir() * Add ImageReadMode to imports * Fix failing test due to incorrect image path
1 parent 4ab46e5 commit 8e24479

File tree

1 file changed

+41
-47
lines changed

1 file changed

+41
-47
lines changed

test/test_image.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
import os
2-
import io
31
import glob
2+
import io
3+
import os
44
import unittest
55

6+
import numpy as np
67
import torch
78
from PIL import Image
9+
from common_utils import get_tmp_dir
10+
811
from torchvision.io.image import (
912
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
1013
encode_png, write_png, write_file, ImageReadMode)
11-
import numpy as np
12-
13-
from common_utils import get_tmp_dir
14-
1514

1615
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
1716
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
@@ -22,14 +21,10 @@
2221

2322
def get_images(directory, img_ext):
2423
assert os.path.isdir(directory)
25-
for root, _, files in os.walk(directory):
26-
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
27-
continue
28-
29-
for fl in files:
30-
_, ext = os.path.splitext(fl)
31-
if ext == img_ext:
32-
yield os.path.join(root, fl)
24+
image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
25+
for path in image_paths:
26+
if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
27+
yield path
3328

3429

3530
def pil_read_image(img_path):
@@ -75,7 +70,7 @@ def test_decode_jpeg(self):
7570
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
7671

7772
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
78-
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
73+
decode_jpeg(torch.empty((100,), dtype=torch.float16))
7974

8075
with self.assertRaises(RuntimeError):
8176
decode_jpeg(torch.empty((100), dtype=torch.uint8))
@@ -119,12 +114,12 @@ def test_encode_jpeg(self):
119114

120115
with self.assertRaisesRegex(
121116
ValueError, "Image quality should be a positive number "
122-
"between 1 and 100"):
117+
"between 1 and 100"):
123118
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
124119

125120
with self.assertRaisesRegex(
126121
ValueError, "Image quality should be a positive number "
127-
"between 1 and 100"):
122+
"between 1 and 100"):
128123
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
129124

130125
with self.assertRaisesRegex(
@@ -140,27 +135,27 @@ def test_encode_jpeg(self):
140135
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
141136

142137
def test_write_jpeg(self):
143-
for img_path in get_images(ENCODE_JPEG, ".jpg"):
144-
data = read_file(img_path)
145-
img = decode_jpeg(data)
138+
with get_tmp_dir() as d:
139+
for img_path in get_images(ENCODE_JPEG, ".jpg"):
140+
data = read_file(img_path)
141+
img = decode_jpeg(data)
146142

147-
basedir = os.path.dirname(img_path)
148-
filename, _ = os.path.splitext(os.path.basename(img_path))
149-
torch_jpeg = os.path.join(
150-
basedir, '{0}_torch.jpg'.format(filename))
151-
pil_jpeg = os.path.join(
152-
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
143+
basedir = os.path.dirname(img_path)
144+
filename, _ = os.path.splitext(os.path.basename(img_path))
145+
torch_jpeg = os.path.join(
146+
d, '{0}_torch.jpg'.format(filename))
147+
pil_jpeg = os.path.join(
148+
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
153149

154-
write_jpeg(img, torch_jpeg, quality=75)
150+
write_jpeg(img, torch_jpeg, quality=75)
155151

156-
with open(torch_jpeg, 'rb') as f:
157-
torch_bytes = f.read()
152+
with open(torch_jpeg, 'rb') as f:
153+
torch_bytes = f.read()
158154

159-
with open(pil_jpeg, 'rb') as f:
160-
pil_bytes = f.read()
155+
with open(pil_jpeg, 'rb') as f:
156+
pil_bytes = f.read()
161157

162-
os.remove(torch_jpeg)
163-
self.assertEqual(torch_bytes, pil_bytes)
158+
self.assertEqual(torch_bytes, pil_bytes)
164159

165160
def test_decode_png(self):
166161
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
@@ -216,20 +211,19 @@ def test_encode_png(self):
216211
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
217212

218213
def test_write_png(self):
219-
for img_path in get_images(IMAGE_DIR, '.png'):
220-
pil_image = Image.open(img_path)
221-
img_pil = torch.from_numpy(np.array(pil_image))
222-
img_pil = img_pil.permute(2, 0, 1)
223-
224-
basedir = os.path.dirname(img_path)
225-
filename, _ = os.path.splitext(os.path.basename(img_path))
226-
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename))
227-
write_png(img_pil, torch_png, compression_level=6)
228-
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
229-
os.remove(torch_png)
230-
saved_image = saved_image.permute(2, 0, 1)
231-
232-
self.assertTrue(img_pil.equal(saved_image))
214+
with get_tmp_dir() as d:
215+
for img_path in get_images(IMAGE_DIR, '.png'):
216+
pil_image = Image.open(img_path)
217+
img_pil = torch.from_numpy(np.array(pil_image))
218+
img_pil = img_pil.permute(2, 0, 1)
219+
220+
filename, _ = os.path.splitext(os.path.basename(img_path))
221+
torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
222+
write_png(img_pil, torch_png, compression_level=6)
223+
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
224+
saved_image = saved_image.permute(2, 0, 1)
225+
226+
self.assertTrue(img_pil.equal(saved_image))
233227

234228
def test_read_file(self):
235229
with get_tmp_dir() as d:

0 commit comments

Comments
 (0)