@@ -233,6 +233,8 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
         raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
 
 
+# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
+# https://github.com/pytorch/vision/issues/6250
 def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
@@ -437,18 +439,18 @@ def __init__(
 
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-        if padding is not None:
-            _check_padding_arg(padding)
-
-        if (padding is not None) or pad_if_needed:
-            _check_padding_mode_arg(padding_mode)
-            _check_fill_arg(fill)
-
         self.padding = padding
         self.pad_if_needed = pad_if_needed
         self.fill = fill
         self.padding_mode = padding_mode
 
+        self._pad_op = None
+        if self.padding is not None:
+            self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode)
+
+        if self.pad_if_needed:
+            self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode)
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         _, height, width = get_image_dimensions(image)
@@ -466,34 +468,36 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         left = torch.randint(0, width - output_width + 1, size=(1,)).item()
         return dict(top=top, left=left, height=output_height, width=output_width)
 
-    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
-        if self.padding is not None:
-            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-
-        image = query_image(flat_inputs)
-        _, height, width = get_image_dimensions(image)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.crop(inpt, **params)
 
-        # pad the width if needed
-        if self.pad_if_needed and width < self.size[1]:
-            padding = [self.size[1] - width, 0]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-        # pad the height if needed
-        if self.pad_if_needed and height < self.size[0]:
-            padding = [0, self.size[0] - height]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
 
-        params = self._get_params(flat_inputs)
+        if self._pad_op is not None:
+            sample = self._pad_op(sample)
 
-        return [F.crop(flat_input, **params) for flat_input in flat_inputs]
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
 
-    def forward(self, *inputs: Any) -> Any:
-        from torch.utils._pytree import tree_flatten, tree_unflatten
+        if self.pad_if_needed:
+            # This check is to explicitly ensure that self._pad_op is defined
+            if self._pad_op is None:
+                raise RuntimeError(
+                    "Internal error, self._pad_op is None. "
+                    "Please, fill an issue about that on https://github.com/pytorch/vision/issues"
+                )
 
-        sample = inputs if len(inputs) > 1 else inputs[0]
+            # pad the width if needed
+            if width < self.size[1]:
+                self._pad_op.padding = [self.size[1] - width, 0]
+                sample = self._pad_op(sample)
+            # pad the height if needed
+            if height < self.size[0]:
+                self._pad_op.padding = [0, self.size[0] - height]
+                sample = self._pad_op(sample)
 
-        flat_inputs, spec = tree_flatten(sample)
-        out_flat_inputs = self._forward(flat_inputs)
-        return tree_unflatten(out_flat_inputs, spec)
+        return super().forward(sample)
 
 
 class RandomPerspective(_RandomApplyTransform):