@@ -314,3 +314,79 @@ def resized_crop_image_pil(
 ) -> PIL.Image.Image:
     img = crop_image_pil(img, top, left, height, width)
     return resize_image_pil(img, size, interpolation=interpolation)
+
+
+def _parse_five_crop_size(size: List[int]) -> List[int]:
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])  # type: ignore[assignment]
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")
+
+    return size
+
+
+def five_crop_image_tensor(
+    img: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_tensor(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_tensor(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def five_crop_image_pil(
+    img: PIL.Image.Image, size: List[int]
+) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_pil(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_pil(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    tl, tr, bl, br, center = five_crop_image_tensor(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_tensor(img)
+    else:
+        img = horizontal_flip_image_tensor(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
+    tl, tr, bl, br, center = five_crop_image_pil(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_pil(img)
+    else:
+        img = horizontal_flip_image_pil(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
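For context (this sketch is not part of the commit): the five crops above are taken at the four corners and the center of the image. The snippet below reproduces the same corner offsets (0, image_height - crop_height, image_width - crop_width) with plain tensor slicing on an assumed dummy CHW tensor; the center offset here uses floor division, whereas the center_crop_image_tensor kernel in the diff rounds, which can differ by one pixel when the size difference is odd.

# Illustrative only, not from the commit: five-crop corner arithmetic on a dummy tensor.
import torch

img = torch.arange(3 * 8 * 10, dtype=torch.float32).reshape(3, 8, 10)  # assumed CHW input
crop_height, crop_width = 4, 4
_, image_height, image_width = img.shape

tl = img[..., :crop_height, :crop_width]                               # top-left
tr = img[..., :crop_height, image_width - crop_width:]                 # top-right
bl = img[..., image_height - crop_height:, :crop_width]                # bottom-left
br = img[..., image_height - crop_height:, image_width - crop_width:]  # bottom-right
top = (image_height - crop_height) // 2                                # center crop (floor division here)
left = (image_width - crop_width) // 2
center = img[..., top:top + crop_height, left:left + crop_width]

assert all(c.shape == (3, crop_height, crop_width) for c in (tl, tr, bl, br, center))

The ten-crop kernels then return this list of five crops followed by the five crops of the flipped image (horizontal flip by default, vertical when vertical_flip=True), giving ten views per input.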