|
53 | 53 | "RandomAdjustSharpness", |
54 | 54 | "RandomAutocontrast", |
55 | 55 | "RandomEqualize", |
| 56 | + "ElasticTransform", |
56 | 57 | ] |
57 | 58 |
|
58 | 59 |
|
@@ -2049,3 +2050,117 @@ def forward(self, img): |
2049 | 2050 |
|
2050 | 2051 | def __repr__(self) -> str: |
2051 | 2052 | return f"{self.__class__.__name__}(p={self.p})" |
| 2053 | + |
| 2054 | + |
| 2055 | +class ElasticTransform(torch.nn.Module): |
| 2056 | + """Transform a tensor image with elastic transformations. |
| 2057 | + Given alpha and sigma, it will generate displacement |
| 2058 | + vectors for all pixels based on random offsets. Alpha controls the strength |
| 2059 | + and sigma controls the smoothness of the displacements. |
| 2060 | + The displacements are added to an identity grid and the resulting grid is |
| 2061 | + used to grid_sample from the image. |
| 2062 | +
|
| 2063 | + Applications: |
| 2064 | + Randomly transforms the morphology of objects in images and produces a |
| 2065 | + see-through-water-like effect. |
| 2066 | +
|
| 2067 | + Args: |
| 2068 | + alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0. |
| 2069 | + sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0. |
| 2070 | + interpolation (InterpolationMode): Desired interpolation enum defined by |
| 2071 | + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. |
| 2072 | + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. |
| 2073 | + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. |
| 2074 | + fill (sequence or number): Pixel fill value for the area outside the transformed |
| 2075 | + image. Default is ``0``. If given a number, the value is used for all bands respectively. |
| 2076 | +
|
| 2077 | + """ |
| 2078 | + |
| 2079 | + def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0): |
| 2080 | + super().__init__() |
| 2081 | + _log_api_usage_once(self) |
| 2082 | + if not isinstance(alpha, (float, Sequence)): |
| 2083 | + raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}") |
| 2084 | + if isinstance(alpha, Sequence) and len(alpha) != 2: |
| 2085 | + raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}") |
| 2086 | + if isinstance(alpha, Sequence): |
| 2087 | + for element in alpha: |
| 2088 | + if not isinstance(element, float): |
| 2089 | + raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}") |
| 2090 | + |
| 2091 | + if isinstance(alpha, float): |
| 2092 | + alpha = [float(alpha), float(alpha)] |
| 2093 | + if isinstance(alpha, (list, tuple)) and len(alpha) == 1: |
| 2094 | + alpha = [alpha[0], alpha[0]] |
| 2095 | + |
| 2096 | + self.alpha = alpha |
| 2097 | + |
| 2098 | + if not isinstance(sigma, (float, Sequence)): |
| 2099 | + raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}") |
| 2100 | + if isinstance(sigma, Sequence) and len(sigma) != 2: |
| 2101 | + raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}") |
| 2102 | + if isinstance(sigma, Sequence): |
| 2103 | + for element in sigma: |
| 2104 | + if not isinstance(element, float): |
| 2105 | + raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}") |
| 2106 | + |
| 2107 | + if isinstance(sigma, float): |
| 2108 | + sigma = [float(sigma), float(sigma)] |
| 2109 | + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: |
| 2110 | + sigma = [sigma[0], sigma[0]] |
| 2111 | + |
| 2112 | + self.sigma = sigma |
| 2113 | + |
| 2114 | + # Backward compatibility with integer value |
| 2115 | + if isinstance(interpolation, int): |
| 2116 | + warnings.warn( |
| 2117 | + "Argument interpolation should be of type InterpolationMode instead of int. " |
| 2118 | + "Please, use InterpolationMode enum." |
| 2119 | + ) |
| 2120 | + interpolation = _interpolation_modes_from_int(interpolation) |
| 2121 | + self.interpolation = interpolation |
| 2122 | + |
| 2123 | + if not isinstance(fill, (int, float)): |
| 2124 | + raise TypeError(f"fill should be int or float. Got {type(fill)}") |
| 2125 | + self.fill = fill |
| 2126 | + |
| 2127 | + @staticmethod |
| 2128 | + def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor: |
| 2129 | + dx = torch.rand([1, 1] + size) * 2 - 1 |
| 2130 | + if sigma[0] > 0.0: |
| 2131 | + kx = int(8 * sigma[0] + 1) |
| 2132 | + # if kernel size is even we have to make it odd |
| 2133 | + if kx % 2 == 0: |
| 2134 | + kx += 1 |
| 2135 | + dx = F.gaussian_blur(dx, [kx, kx], sigma) |
| 2136 | + dx = dx * alpha[0] / size[0] |
| 2137 | + |
| 2138 | + dy = torch.rand([1, 1] + size) * 2 - 1 |
| 2139 | + if sigma[1] > 0.0: |
| 2140 | + ky = int(8 * sigma[1] + 1) |
| 2141 | + # if kernel size is even we have to make it odd |
| 2142 | + if ky % 2 == 0: |
| 2143 | + ky += 1 |
| 2144 | + dy = F.gaussian_blur(dy, [ky, ky], sigma) |
| 2145 | + dy = dy * alpha[1] / size[1] |
| 2146 | + return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 |
| 2147 | + |
| 2148 | + def forward(self, tensor: Tensor) -> Tensor: |
| 2149 | + """ |
| 2150 | + Args: |
| 2151 | + img (PIL Image or Tensor): Image to be transformed. |
| 2152 | +
|
| 2153 | + Returns: |
| 2154 | + PIL Image or Tensor: Transformed image. |
| 2155 | + """ |
| 2156 | + size = F.get_image_size(tensor)[::-1] |
| 2157 | + displacement = self.get_params(self.alpha, self.sigma, size) |
| 2158 | + return F.elastic_transform(tensor, displacement, self.interpolation, self.fill) |
| 2159 | + |
| 2160 | + def __repr__(self): |
| 2161 | + format_string = self.__class__.__name__ + "(alpha=" |
| 2162 | + format_string += str(self.alpha) + ")" |
| 2163 | + format_string += ", (sigma=" + str(self.sigma) + ")" |
| 2164 | + format_string += ", interpolation={self.interpolation}" |
| 2165 | + format_string += ", fill={self.fill})" |
| 2166 | + return format_string |
0 commit comments