Commit aa4cf03

Improved test of Resize on PIL images (#2874)
1 parent 98146a1 commit aa4cf03

1 file changed: +53 −42 lines

test/test_transforms.py

Lines changed: 53 additions & 42 deletions
@@ -215,53 +215,64 @@ def test_randomperspective_fill(self):
             F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))
 
     def test_resize(self):
-        height = random.randint(24, 32) * 2
-        width = random.randint(24, 32) * 2
-        osize = random.randint(5, 12) * 2
 
-        # TODO: Check output size check for bug-fix, improve this later
-        t = transforms.Resize(osize)
-        self.assertTrue(isinstance(t.size, int))
-        self.assertEqual(t.size, osize)
+        input_sizes = [
+            # height, width
+            # square image
+            (28, 28),
+            (27, 27),
+            # rectangular image: h < w
+            (28, 34),
+            (29, 35),
+            # rectangular image: h > w
+            (34, 28),
+            (35, 29),
+        ]
+        test_output_sizes_1 = [
+            # single integer
+            22, 27, 28, 36,
+            # single integer in tuple/list
+            [22, ], (27, ),
+        ]
+        test_output_sizes_2 = [
+            # two integers
+            [22, 22], [22, 28], [22, 36],
+            [27, 22], [36, 22], [28, 28],
+            [28, 37], [37, 27], [37, 37]
+        ]
+
+        for height, width in input_sizes:
+            img = Image.new("RGB", size=(width, height), color=127)
+
+            for osize in test_output_sizes_1:
+
+                t = transforms.Resize(osize)
+                result = t(img)
+
+                msg = "{}, {} - {}".format(height, width, osize)
+                osize = osize[0] if isinstance(osize, (list, tuple)) else osize
+                # If size is an int, smaller edge of the image will be matched to this number.
+                # i.e, if height > width, then image will be rescaled to (size * height / width, size).
+                if height < width:
+                    expected_size = (int(osize * width / height), osize)  # (w, h)
+                    self.assertEqual(result.size, expected_size, msg=msg)
+                elif width < height:
+                    expected_size = (osize, int(osize * height / width))  # (w, h)
+                    self.assertEqual(result.size, expected_size, msg=msg)
+                else:
+                    expected_size = (osize, osize)  # (w, h)
+                    self.assertEqual(result.size, expected_size, msg=msg)
 
-        img = torch.ones(3, height, width)
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(osize),
-            transforms.ToTensor(),
-        ])(img)
-        self.assertIn(osize, result.size())
-        if height < width:
-            self.assertLessEqual(result.size(1), result.size(2))
-        elif width < height:
-            self.assertGreaterEqual(result.size(1), result.size(2))
+        for height, width in input_sizes:
+            img = Image.new("RGB", size=(width, height), color=127)
 
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([osize, osize]),
-            transforms.ToTensor(),
-        ])(img)
-        self.assertIn(osize, result.size())
-        self.assertEqual(result.size(1), osize)
-        self.assertEqual(result.size(2), osize)
+            for osize in test_output_sizes_2:
+                oheight, owidth = osize
 
-        oheight = random.randint(5, 12) * 2
-        owidth = random.randint(5, 12) * 2
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        self.assertEqual(result.size(1), oheight)
-        self.assertEqual(result.size(2), owidth)
+                t = transforms.Resize(osize)
+                result = t(img)
 
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([oheight, owidth]),
-            transforms.ToTensor(),
-        ])(img)
-        self.assertEqual(result.size(1), oheight)
-        self.assertEqual(result.size(2), owidth)
+                self.assertEqual((owidth, oheight), result.size)
 
     def test_random_crop(self):
         height = random.randint(10, 32) * 2
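
For context, the smaller-edge behaviour the new assertions check can be reproduced with a minimal standalone sketch. This snippet is not part of the commit; it assumes a torchvision install with Pillow available and mirrors the (29, 35) entry from input_sizes:

from PIL import Image
from torchvision import transforms

# Rectangular image with h < w, matching the (29, 35) case in the test.
# Note that PIL's Image.new takes size as (width, height).
img = Image.new("RGB", size=(35, 29), color=127)

# With an integer size, Resize matches the smaller edge (the height, 29) to 28
# and scales the longer edge proportionally: int(28 * 35 / 29) == 33.
result = transforms.Resize(28)(img)
print(result.size)  # (33, 28), reported by PIL as (width, height)

This is exactly the expectation the test computes as expected_size = (int(osize * width / height), osize) for the h < w branch.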
