diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 512ffec0abb..8d5cc24d25a 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -22,7 +22,8 @@ def __init__( self.crop_height = size[0] self.crop_width = size[1] - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode @@ -118,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["needs_pad"]: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index d8ab0bb2410..f1eed87b9c0 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -255,9 +255,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") return params @@ -276,11 +274,12 @@ def __init__( if not isinstance(padding, int): padding = list(padding) self.padding = padding - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] @@ -293,7 +292,8 @@ def __init__( ) -> None: super().__init__(p=p) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) _check_sequence_input(side_range, "side_range", req_sizes=(2,)) @@ -318,7 +318,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.pad(inpt, **params, fill=fill) @@ -338,7 +338,8 @@ def __init__( self.interpolation = _check_interpolation(interpolation) self.expand = expand - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -350,7 +351,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.rotate( inpt, **params, @@ -395,7 +396,8 @@ def __init__( self.shear = shear self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -430,7 +432,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.affine( inpt, **params, @@ -447,9 +449,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") padding = self.padding if padding is not None: @@ -478,7 +478,8 @@ def __init__( self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] self.pad_if_needed = pad_if_needed - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: @@ -541,7 +542,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: @@ -567,7 +568,8 @@ def __init__( self.distortion_scale = distortion_scale self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_spatial_size(flat_inputs) @@ -600,7 +602,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(coefficients=perspective_coeffs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.perspective( inpt, None, @@ -626,7 +628,8 @@ def __init__( self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = list(query_spatial_size(flat_inputs)) @@ -652,7 +655,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.elastic( inpt, **params, diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 3f92b3c1646..f83ed5d6e11 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -108,30 +108,17 @@ def __init_subclass__(cls) -> None: def _extract_params_for_v1_transform(self) -> Dict[str, Any]: # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current - # v2 transform instance. It does two things: - # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general - # 2. If available handle the `fill` attribute for v1 compatibility (see below for details) + # v2 transform instance. It extracts all available public attributes that are specific to that transform and + # not `nn.Module` in general. # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen # if the v2 transform introduced new parameters that are not support by the v1 transform. common_attrs = nn.Module().__dict__.keys() - params = { + return { attr: value for attr, value in self.__dict__.items() if not attr.startswith("_") and attr not in common_attrs } - # transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed - # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value - # for the different datapoint types. Below we extract the value for tensors and return that together with the - # other params. - # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and - # `RandomRotation` - if "fill" in params: - fill_type_defaultdict = params.pop("fill") - params["fill"] = fill_type_defaultdict[torch.Tensor] - - return params - def __prepare_scriptable__(self) -> nn.Module: # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return # value is used for scripting over the original object that should have been scripted. Since the v1 transforms