@@ -233,6 +233,8 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
         raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


+# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
+# https://github.com/pytorch/vision/issues/6250
 def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
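
For context on the TODO above, here is a minimal sketch of the idea using only the standard library's str-mixin Enum rather than the torchvision._utils.StrEnum helper it mentions; the PaddingMode name and this variant of the check are illustrative only, not part of torchvision:

from enum import Enum


class PaddingMode(str, Enum):
    CONSTANT = "constant"
    EDGE = "edge"
    REFLECT = "reflect"
    SYMMETRIC = "symmetric"


def _check_padding_mode_arg(padding_mode: str) -> None:
    # A str-backed enum member compares equal to its plain-string value,
    # so both PaddingMode.EDGE and "edge" pass this check.
    if padding_mode not in [mode.value for mode in PaddingMode]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


_check_padding_mode_arg("reflect")         # plain string still accepted
_check_padding_mode_arg(PaddingMode.EDGE)  # enum member accepted as well
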
@@ -437,18 +439,18 @@ def __init__(

         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

-        if padding is not None:
-            _check_padding_arg(padding)
-
-        if (padding is not None) or pad_if_needed:
-            _check_padding_mode_arg(padding_mode)
-            _check_fill_arg(fill)
-
         self.padding = padding
         self.pad_if_needed = pad_if_needed
         self.fill = fill
         self.padding_mode = padding_mode

+        self._pad_op = None
+        if self.padding is not None:
+            self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode)
+
+        if self.pad_if_needed:
+            self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode)
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         _, height, width = get_image_dimensions(image)
@@ -466,34 +468,36 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         left = torch.randint(0, width - output_width + 1, size=(1,)).item()
         return dict(top=top, left=left, height=output_height, width=output_width)

-    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
-        if self.padding is not None:
-            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-
-        image = query_image(flat_inputs)
-        _, height, width = get_image_dimensions(image)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.crop(inpt, **params)

-        # pad the width if needed
-        if self.pad_if_needed and width < self.size[1]:
-            padding = [self.size[1] - width, 0]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-        # pad the height if needed
-        if self.pad_if_needed and height < self.size[0]:
-            padding = [0, self.size[0] - height]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]

-        params = self._get_params(flat_inputs)
+        if self._pad_op is not None:
+            sample = self._pad_op(sample)

-        return [F.crop(flat_input, **params) for flat_input in flat_inputs]
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)

-    def forward(self, *inputs: Any) -> Any:
-        from torch.utils._pytree import tree_flatten, tree_unflatten
+        if self.pad_if_needed:
+            # This check explicitly ensures that self._pad_op is defined
+            if self._pad_op is None:
+                raise RuntimeError(
+                    "Internal error, self._pad_op is None. "
+                    "Please file an issue about this at https://github.com/pytorch/vision/issues"
+                )

-        sample = inputs if len(inputs) > 1 else inputs[0]
+            # pad the width if needed
+            if width < self.size[1]:
+                self._pad_op.padding = [self.size[1] - width, 0]
+                sample = self._pad_op(sample)
+            # pad the height if needed
+            if height < self.size[0]:
+                self._pad_op.padding = [0, self.size[0] - height]
+                sample = self._pad_op(sample)

-        flat_inputs, spec = tree_flatten(sample)
-        out_flat_inputs = self._forward(flat_inputs)
-        return tree_unflatten(out_flat_inputs, spec)
+        return super().forward(sample)


 class RandomPerspective(_RandomApplyTransform):
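
For reference, a minimal, self-contained sketch of the pad-then-crop flow that this refactor moves into RandomCrop.forward, written against the stable torchvision.transforms.functional API rather than the prototype transforms; the helper pad_then_random_crop is illustrative only and not part of torchvision:

import torch
import torchvision.transforms.functional as F


def pad_then_random_crop(img, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
    # 1. Apply the user-provided padding up front (the role of self._pad_op).
    if padding is not None:
        img = F.pad(img, padding, fill=fill, padding_mode=padding_mode)

    # 2. Grow the image further if it is still smaller than the crop size.
    height, width = img.shape[-2:]
    if pad_if_needed:
        if width < size[1]:   # pad the width if needed
            img = F.pad(img, [size[1] - width, 0], fill=fill, padding_mode=padding_mode)
        if height < size[0]:  # pad the height if needed
            img = F.pad(img, [0, size[0] - height], fill=fill, padding_mode=padding_mode)

    # 3. Sample the crop window and crop (what _get_params and _transform do).
    height, width = img.shape[-2:]
    top = torch.randint(0, height - size[0] + 1, size=(1,)).item()
    left = torch.randint(0, width - size[1] + 1, size=(1,)).item()
    return F.crop(img, top, left, size[0], size[1])


img = torch.rand(3, 20, 20)
out = pad_then_random_crop(img, size=(32, 32), pad_if_needed=True)
print(out.shape)  # torch.Size([3, 32, 32])

Keeping a single Pad op and updating its padding attribute, as the diff does, avoids re-creating the transform on every call while reusing Pad's own fill and padding_mode handling.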