Skip to content

Commit 195bb86

Browse files
NicolasHugpmeier
andauthored
Use torch.testing.assert_close in test_image.py (#3877)
Co-authored-by: Philip Meier <[email protected]>
1 parent 05e061f commit 195bb86

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

test/test_image.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from PIL import Image
1010
from common_utils import get_tmp_dir, needs_cuda
11+
from _assert_utils import assert_equal
1112

1213
from torchvision.io.image import (
1314
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -107,7 +108,7 @@ def test_encode_jpeg(self):
107108
for src_img in [img, img.contiguous()]:
108109
# PIL sets jpeg quality to 75 by default
109110
jpeg_bytes = encode_jpeg(src_img, quality=75)
110-
self.assertTrue(jpeg_bytes.equal(pil_bytes))
111+
assert_equal(jpeg_bytes, pil_bytes)
111112

112113
with self.assertRaisesRegex(
113114
RuntimeError, "Input tensor dtype should be uint8"):
@@ -191,7 +192,7 @@ def test_encode_png(self):
191192
rec_img = torch.from_numpy(np.array(rec_img))
192193
rec_img = rec_img.permute(2, 0, 1)
193194

194-
self.assertTrue(img_pil.equal(rec_img))
195+
assert_equal(img_pil, rec_img)
195196

196197
with self.assertRaisesRegex(
197198
RuntimeError, "Input tensor dtype should be uint8"):
@@ -224,7 +225,7 @@ def test_write_png(self):
224225
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
225226
saved_image = saved_image.permute(2, 0, 1)
226227

227-
self.assertTrue(img_pil.equal(saved_image))
228+
assert_equal(img_pil, saved_image)
228229

229230
def test_read_file(self):
230231
with get_tmp_dir() as d:
@@ -235,7 +236,7 @@ def test_read_file(self):
235236

236237
data = read_file(fpath)
237238
expected = torch.tensor(list(content), dtype=torch.uint8)
238-
self.assertTrue(data.equal(expected))
239+
assert_equal(data, expected)
239240
os.unlink(fpath)
240241

241242
with self.assertRaisesRegex(
@@ -251,7 +252,7 @@ def test_read_file_non_ascii(self):
251252

252253
data = read_file(fpath)
253254
expected = torch.tensor(list(content), dtype=torch.uint8)
254-
self.assertTrue(data.equal(expected))
255+
assert_equal(data, expected)
255256
os.unlink(fpath)
256257

257258
def test_write_file(self):

0 commit comments

Comments
 (0)