8
8
import torch
9
9
from PIL import Image
10
10
from common_utils import get_tmp_dir , needs_cuda
11
+ from _assert_utils import assert_equal
11
12
12
13
from torchvision .io .image import (
13
14
decode_png , decode_jpeg , encode_jpeg , write_jpeg , decode_image , read_file ,
@@ -107,7 +108,7 @@ def test_encode_jpeg(self):
107
108
for src_img in [img , img .contiguous ()]:
108
109
# PIL sets jpeg quality to 75 by default
109
110
jpeg_bytes = encode_jpeg (src_img , quality = 75 )
110
- self . assertTrue (jpeg_bytes . equal ( pil_bytes ) )
111
+ assert_equal (jpeg_bytes , pil_bytes )
111
112
112
113
with self .assertRaisesRegex (
113
114
RuntimeError , "Input tensor dtype should be uint8" ):
@@ -191,7 +192,7 @@ def test_encode_png(self):
191
192
rec_img = torch .from_numpy (np .array (rec_img ))
192
193
rec_img = rec_img .permute (2 , 0 , 1 )
193
194
194
- self . assertTrue (img_pil . equal ( rec_img ) )
195
+ assert_equal (img_pil , rec_img )
195
196
196
197
with self .assertRaisesRegex (
197
198
RuntimeError , "Input tensor dtype should be uint8" ):
@@ -224,7 +225,7 @@ def test_write_png(self):
224
225
saved_image = torch .from_numpy (np .array (Image .open (torch_png )))
225
226
saved_image = saved_image .permute (2 , 0 , 1 )
226
227
227
- self . assertTrue (img_pil . equal ( saved_image ) )
228
+ assert_equal (img_pil , saved_image )
228
229
229
230
def test_read_file (self ):
230
231
with get_tmp_dir () as d :
@@ -235,7 +236,7 @@ def test_read_file(self):
235
236
236
237
data = read_file (fpath )
237
238
expected = torch .tensor (list (content ), dtype = torch .uint8 )
238
- self . assertTrue (data . equal ( expected ) )
239
+ assert_equal (data , expected )
239
240
os .unlink (fpath )
240
241
241
242
with self .assertRaisesRegex (
@@ -251,7 +252,7 @@ def test_read_file_non_ascii(self):
251
252
252
253
data = read_file (fpath )
253
254
expected = torch .tensor (list (content ), dtype = torch .uint8 )
254
- self . assertTrue (data . equal ( expected ) )
255
+ assert_equal (data , expected )
255
256
os .unlink (fpath )
256
257
257
258
def test_write_file (self ):
0 commit comments