1
- import os
2
- import io
3
1
import glob
2
+ import io
3
+ import os
4
4
import unittest
5
5
6
+ import numpy as np
6
7
import torch
7
8
from PIL import Image
9
+ from common_utils import get_tmp_dir
10
+
8
11
from torchvision .io .image import (
9
12
decode_png , decode_jpeg , encode_jpeg , write_jpeg , decode_image , read_file ,
10
13
encode_png , write_png , write_file , ImageReadMode )
11
- import numpy as np
12
-
13
- from common_utils import get_tmp_dir
14
-
15
14
16
15
IMAGE_ROOT = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" )
17
16
FAKEDATA_DIR = os .path .join (IMAGE_ROOT , "fakedata" )
22
21
23
22
def get_images (directory , img_ext ):
24
23
assert os .path .isdir (directory )
25
- for root , _ , files in os .walk (directory ):
26
- if os .path .basename (root ) in {'damaged_jpeg' , 'jpeg_write' }:
27
- continue
28
-
29
- for fl in files :
30
- _ , ext = os .path .splitext (fl )
31
- if ext == img_ext :
32
- yield os .path .join (root , fl )
24
+ image_paths = glob .glob (directory + f'/**/*{ img_ext } ' , recursive = True )
25
+ for path in image_paths :
26
+ if path .split (os .sep )[- 2 ] not in ['damaged_jpeg' , 'jpeg_write' ]:
27
+ yield path
33
28
34
29
35
30
def pil_read_image (img_path ):
@@ -75,7 +70,7 @@ def test_decode_jpeg(self):
75
70
decode_jpeg (torch .empty ((100 , 1 ), dtype = torch .uint8 ))
76
71
77
72
with self .assertRaisesRegex (RuntimeError , "Expected a torch.uint8 tensor" ):
78
- decode_jpeg (torch .empty ((100 , ), dtype = torch .float16 ))
73
+ decode_jpeg (torch .empty ((100 ,), dtype = torch .float16 ))
79
74
80
75
with self .assertRaises (RuntimeError ):
81
76
decode_jpeg (torch .empty ((100 ), dtype = torch .uint8 ))
@@ -119,12 +114,12 @@ def test_encode_jpeg(self):
119
114
120
115
with self .assertRaisesRegex (
121
116
ValueError , "Image quality should be a positive number "
122
- "between 1 and 100" ):
117
+ "between 1 and 100" ):
123
118
encode_jpeg (torch .empty ((3 , 100 , 100 ), dtype = torch .uint8 ), quality = - 1 )
124
119
125
120
with self .assertRaisesRegex (
126
121
ValueError , "Image quality should be a positive number "
127
- "between 1 and 100" ):
122
+ "between 1 and 100" ):
128
123
encode_jpeg (torch .empty ((3 , 100 , 100 ), dtype = torch .uint8 ), quality = 101 )
129
124
130
125
with self .assertRaisesRegex (
@@ -140,27 +135,27 @@ def test_encode_jpeg(self):
140
135
encode_jpeg (torch .empty ((100 , 100 ), dtype = torch .uint8 ))
141
136
142
137
def test_write_jpeg (self ):
143
- for img_path in get_images (ENCODE_JPEG , ".jpg" ):
144
- data = read_file (img_path )
145
- img = decode_jpeg (data )
138
+ with get_tmp_dir () as d :
139
+ for img_path in get_images (ENCODE_JPEG , ".jpg" ):
140
+ data = read_file (img_path )
141
+ img = decode_jpeg (data )
146
142
147
- basedir = os .path .dirname (img_path )
148
- filename , _ = os .path .splitext (os .path .basename (img_path ))
149
- torch_jpeg = os .path .join (
150
- basedir , '{0}_torch.jpg' .format (filename ))
151
- pil_jpeg = os .path .join (
152
- basedir , 'jpeg_write' , '{0}_pil.jpg' .format (filename ))
143
+ basedir = os .path .dirname (img_path )
144
+ filename , _ = os .path .splitext (os .path .basename (img_path ))
145
+ torch_jpeg = os .path .join (
146
+ d , '{0}_torch.jpg' .format (filename ))
147
+ pil_jpeg = os .path .join (
148
+ basedir , 'jpeg_write' , '{0}_pil.jpg' .format (filename ))
153
149
154
- write_jpeg (img , torch_jpeg , quality = 75 )
150
+ write_jpeg (img , torch_jpeg , quality = 75 )
155
151
156
- with open (torch_jpeg , 'rb' ) as f :
157
- torch_bytes = f .read ()
152
+ with open (torch_jpeg , 'rb' ) as f :
153
+ torch_bytes = f .read ()
158
154
159
- with open (pil_jpeg , 'rb' ) as f :
160
- pil_bytes = f .read ()
155
+ with open (pil_jpeg , 'rb' ) as f :
156
+ pil_bytes = f .read ()
161
157
162
- os .remove (torch_jpeg )
163
- self .assertEqual (torch_bytes , pil_bytes )
158
+ self .assertEqual (torch_bytes , pil_bytes )
164
159
165
160
def test_decode_png (self ):
166
161
conversion = [(None , ImageReadMode .UNCHANGED ), ("L" , ImageReadMode .GRAY ), ("LA" , ImageReadMode .GRAY_ALPHA ),
@@ -216,20 +211,19 @@ def test_encode_png(self):
216
211
encode_png (torch .empty ((5 , 100 , 100 ), dtype = torch .uint8 ))
217
212
218
213
def test_write_png (self ):
219
- for img_path in get_images (IMAGE_DIR , '.png' ):
220
- pil_image = Image .open (img_path )
221
- img_pil = torch .from_numpy (np .array (pil_image ))
222
- img_pil = img_pil .permute (2 , 0 , 1 )
223
-
224
- basedir = os .path .dirname (img_path )
225
- filename , _ = os .path .splitext (os .path .basename (img_path ))
226
- torch_png = os .path .join (basedir , '{0}_torch.png' .format (filename ))
227
- write_png (img_pil , torch_png , compression_level = 6 )
228
- saved_image = torch .from_numpy (np .array (Image .open (torch_png )))
229
- os .remove (torch_png )
230
- saved_image = saved_image .permute (2 , 0 , 1 )
231
-
232
- self .assertTrue (img_pil .equal (saved_image ))
214
+ with get_tmp_dir () as d :
215
+ for img_path in get_images (IMAGE_DIR , '.png' ):
216
+ pil_image = Image .open (img_path )
217
+ img_pil = torch .from_numpy (np .array (pil_image ))
218
+ img_pil = img_pil .permute (2 , 0 , 1 )
219
+
220
+ filename , _ = os .path .splitext (os .path .basename (img_path ))
221
+ torch_png = os .path .join (d , '{0}_torch.png' .format (filename ))
222
+ write_png (img_pil , torch_png , compression_level = 6 )
223
+ saved_image = torch .from_numpy (np .array (Image .open (torch_png )))
224
+ saved_image = saved_image .permute (2 , 0 , 1 )
225
+
226
+ self .assertTrue (img_pil .equal (saved_image ))
233
227
234
228
def test_read_file (self ):
235
229
with get_tmp_dir () as d :
0 commit comments