@@ -215,53 +215,64 @@ def test_randomperspective_fill(self):
215
215
F .perspective (img_conv , startpoints , endpoints , fill = tuple ([fill ] * wrong_num_bands ))
216
216
217
217
def test_resize (self ):
218
- height = random .randint (24 , 32 ) * 2
219
- width = random .randint (24 , 32 ) * 2
220
- osize = random .randint (5 , 12 ) * 2
221
218
222
- # TODO: Check output size check for bug-fix, improve this later
223
- t = transforms .Resize (osize )
224
- self .assertTrue (isinstance (t .size , int ))
225
- self .assertEqual (t .size , osize )
219
+ input_sizes = [
220
+ # height, width
221
+ # square image
222
+ (28 , 28 ),
223
+ (27 , 27 ),
224
+ # rectangular image: h < w
225
+ (28 , 34 ),
226
+ (29 , 35 ),
227
+ # rectangular image: h > w
228
+ (34 , 28 ),
229
+ (35 , 29 ),
230
+ ]
231
+ test_output_sizes_1 = [
232
+ # single integer
233
+ 22 , 27 , 28 , 36 ,
234
+ # single integer in tuple/list
235
+ [22 , ], (27 , ),
236
+ ]
237
+ test_output_sizes_2 = [
238
+ # two integers
239
+ [22 , 22 ], [22 , 28 ], [22 , 36 ],
240
+ [27 , 22 ], [36 , 22 ], [28 , 28 ],
241
+ [28 , 37 ], [37 , 27 ], [37 , 37 ]
242
+ ]
243
+
244
+ for height , width in input_sizes :
245
+ img = Image .new ("RGB" , size = (width , height ), color = 127 )
246
+
247
+ for osize in test_output_sizes_1 :
248
+
249
+ t = transforms .Resize (osize )
250
+ result = t (img )
251
+
252
+ msg = "{}, {} - {}" .format (height , width , osize )
253
+ osize = osize [0 ] if isinstance (osize , (list , tuple )) else osize
254
+ # If size is an int, smaller edge of the image will be matched to this number.
255
+ # i.e, if height > width, then image will be rescaled to (size * height / width, size).
256
+ if height < width :
257
+ expected_size = (int (osize * width / height ), osize ) # (w, h)
258
+ self .assertEqual (result .size , expected_size , msg = msg )
259
+ elif width < height :
260
+ expected_size = (osize , int (osize * height / width )) # (w, h)
261
+ self .assertEqual (result .size , expected_size , msg = msg )
262
+ else :
263
+ expected_size = (osize , osize ) # (w, h)
264
+ self .assertEqual (result .size , expected_size , msg = msg )
226
265
227
- img = torch .ones (3 , height , width )
228
- result = transforms .Compose ([
229
- transforms .ToPILImage (),
230
- transforms .Resize (osize ),
231
- transforms .ToTensor (),
232
- ])(img )
233
- self .assertIn (osize , result .size ())
234
- if height < width :
235
- self .assertLessEqual (result .size (1 ), result .size (2 ))
236
- elif width < height :
237
- self .assertGreaterEqual (result .size (1 ), result .size (2 ))
266
+ for height , width in input_sizes :
267
+ img = Image .new ("RGB" , size = (width , height ), color = 127 )
238
268
239
- result = transforms .Compose ([
240
- transforms .ToPILImage (),
241
- transforms .Resize ([osize , osize ]),
242
- transforms .ToTensor (),
243
- ])(img )
244
- self .assertIn (osize , result .size ())
245
- self .assertEqual (result .size (1 ), osize )
246
- self .assertEqual (result .size (2 ), osize )
269
+ for osize in test_output_sizes_2 :
270
+ oheight , owidth = osize
247
271
248
- oheight = random .randint (5 , 12 ) * 2
249
- owidth = random .randint (5 , 12 ) * 2
250
- result = transforms .Compose ([
251
- transforms .ToPILImage (),
252
- transforms .Resize ((oheight , owidth )),
253
- transforms .ToTensor (),
254
- ])(img )
255
- self .assertEqual (result .size (1 ), oheight )
256
- self .assertEqual (result .size (2 ), owidth )
272
+ t = transforms .Resize (osize )
273
+ result = t (img )
257
274
258
- result = transforms .Compose ([
259
- transforms .ToPILImage (),
260
- transforms .Resize ([oheight , owidth ]),
261
- transforms .ToTensor (),
262
- ])(img )
263
- self .assertEqual (result .size (1 ), oheight )
264
- self .assertEqual (result .size (2 ), owidth )
275
+ self .assertEqual ((owidth , oheight ), result .size )
265
276
266
277
def test_random_crop (self ):
267
278
height = random .randint (10 , 32 ) * 2
0 commit comments