Skip to content

Commit 9430be7

Browse files
JLrumberger and vfdev-5 authored
Added elastic transform in torchvision.transforms (#4938)
* Added elastic augment * ufmt formatting * updated comments * fixed circular dependency issue and bare except error * Fixed three type checking errors in functional_tensor.py * ufmt formatted * changed elastic_deformation to a more common implementation Implementation uses alpha and sigma to control strength and smoothness of the displacement vectors in elastic_deformation instead of control_point_spacings and sigma. * ufmt formatting * Some performance updates Put random offset vectors to device before gaussian_blur is applied speeds it up 3-fold. * fixed type error * fixed again a type error * Update torchvision/transforms/functional_tensor.py Co-authored-by: vfdev <[email protected]> * Added some requested changes - pil image support similar to GaussianBlur - changed interpolation arg to InterpolationMode - added a wrapper in torchvision.transforms.functional.py that gets called by the class in transforms.py -renamed it to ElasticTransform - handled sigma = 0 case * added img docstring * added some tests * Updated tests and the code * Added the requested changes to the arguments of F.elastic_transform Added random_state and displacement as arguments to F.elastic_transform * fixed the type error * Fixed tests and docs * implemented requested changes Changes: 1) alpha AND sigma OR displacement must be given as arguments to transforms.functional_tensor.elastic_transform instead of alpha AND sigma AND displacement 2) displacements are accepted in transforms.functional.elastic_transform as np.array and torch.Tensor instead of only accepting torch.Tensor * ufmt formatting * trochscript error resolved replaced torch.from_numpy() to torch.Tensor() to make it compatible to torchscript * revert to torch.from_numpy() * updated argument checks and errors - In F.elastic_transform added check to see if both user inputs img and displacement are either of type PIL Image and ndarray or both of type tensor. 
- In F_t.elastic_transform added check if alpha and sigma are None if displacement is given or vice versa. * fixed seed error changed torch.seed to torch.manual_seed in F_t.elastic_transform * Reverted displacement type and other cosmetics * Other minor improvements * changed gaussian_blur filter size changed gaussian_blur filter size from 4 * int(sigma) + 1 to int(8 * sigma + 1) to make it consistent with ernestums implementation * resolved merge error * Revert "resolved merge error" This reverts commit 6a4a4e7. * resolve merge error * ufmt formatted * ufmt formated once again.. * fixed unsupported operand error * Update API and removed random_state from functional part * Added default values * Added ElasticTransform to gallery and updated the docstring * Updated gallery and added _log_api_usage_once BTW, matplotlib.pylab is deprecated * Updated gallery transforms code * Updates according to review Co-authored-by: vfdev <[email protected]>
1 parent 8e0f691 commit 9430be7

File tree

7 files changed

+287
-1
lines changed

7 files changed

+287
-1
lines changed

gallery/plot_transforms.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
149149
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
150150
plot(affine_imgs)
151151

152+
####################################
153+
# ElasticTransform
154+
# ~~~~~~~~~~~~~~~~
155+
# The :class:`~torchvision.transforms.ElasticTransform` transform
156+
# (see also :func:`~torchvision.transforms.functional.elastic_transform`)
157+
# randomly transforms the morphology of objects in images and produces a
158+
# see-through-water-like effect.
159+
elastic_transformer = T.ElasticTransform(alpha=250.0)
160+
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
161+
plot(transformed_imgs)
162+
152163
####################################
153164
# RandomCrop
154165
# ~~~~~~~~~~

gallery/plot_video_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def __iter__(self):
325325
# ----------------------------------
326326
# Example of visualized video
327327

328-
import matplotlib.pylab as plt
328+
import matplotlib.pyplot as plt
329329

330330
plt.figure(figsize=(12, 12))
331331
for i in range(16):

test/test_functional_tensor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,5 +1363,44 @@ def test_ten_crop(device):
13631363
assert_equal(transformed_batch, s_transformed_batch)
13641364

13651365

1366+
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    [
        None,
        [255, 255, 255],
        (2.0,),
    ],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
    """Check eager vs. scripted ``F.elastic_transform`` and single-image vs. batched results."""
    scripted_fn = torch.jit.script(F.elastic_transform)

    # elastic_transform has no PIL implementation, so there is no
    # tensor-vs-pillow comparison in this test.
    image, _ = _create_data(32, 34, device=device)
    if dt is not None:
        image = image.to(dt)

    # The displacement is sampled once and shared by both calls so that the
    # eager and scripted outputs are comparable.
    call_kwargs = {
        "displacement": T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [32, 34]),
        "interpolation": interpolation,
        "fill": fill,
    }
    eager_out = F.elastic_transform(image, **call_kwargs)
    scripted_out = scripted_fn(image, **call_kwargs)
    assert_equal(eager_out, scripted_out)

    # Same check on a batch, with a displacement matching the batch's spatial size.
    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    call_kwargs["displacement"] = T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [16, 18])
    if dt is not None:
        batch = batch.to(dt)
    _test_fn_on_batch(batch, F.elastic_transform, **call_kwargs)
1403+
1404+
13661405
if __name__ == "__main__":
13671406
pytest.main([__file__])

test/test_transforms.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,5 +2250,42 @@ def test_random_affine():
22502250
assert t.interpolation == transforms.InterpolationMode.BILINEAR
22512251

22522252

2253+
def test_elastic_transformation():
    """Exercise ``ElasticTransform`` argument validation and basic PIL behaviour."""
    # (expected exception, message pattern, constructor kwargs) for every rejected input.
    rejected = [
        (TypeError, r"alpha should be float or a sequence of floats", {"alpha": True, "sigma": 2.0}),
        (TypeError, r"alpha should be a sequence of floats", {"alpha": [1.0, True], "sigma": 2.0}),
        (ValueError, r"alpha is a sequence its length should be 2", {"alpha": [1.0, 0.0, 1.0], "sigma": 2.0}),
        (TypeError, r"sigma should be float or a sequence of floats", {"alpha": 2.0, "sigma": True}),
        (TypeError, r"sigma should be a sequence of floats", {"alpha": 2.0, "sigma": [1.0, True]}),
        (ValueError, r"sigma is a sequence its length should be 2", {"alpha": 2.0, "sigma": [1.0, 0.0, 1.0]}),
    ]
    for expected_exc, pattern, ctor_kwargs in rejected:
        with pytest.raises(expected_exc, match=pattern):
            transforms.ElasticTransform(**ctor_kwargs)

    # An integer interpolation is still accepted, with a warning, and mapped to the enum.
    with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
        t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR

    with pytest.raises(TypeError, match=r"fill should be int or float"):
        transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})

    # With zero strength and zero smoothing the image must come back unchanged.
    x = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    img = F.to_pil_image(x)
    identity = transforms.ElasticTransform(alpha=0.0, sigma=0.0)
    assert identity(img) == img

    # Smoke test on PIL images
    t = transforms.ElasticTransform(alpha=0.5, sigma=0.23)
    assert isinstance(t(img), Image.Image)

    # Checking if ElasticTransform can be printed as string
    repr(t)
2288+
2289+
22532290
if __name__ == "__main__":
22542291
pytest.main([__file__])

torchvision/transforms/functional.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,3 +1484,67 @@ def equalize(img: Tensor) -> Tensor:
14841484
return F_pil.equalize(img)
14851485

14861486
return F_t.equalize(img)
1487+
1488+
1489+
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform an image with a precomputed elastic displacement field.

    This function does not generate displacements itself: the caller supplies
    the displacement field (for example one sampled with
    ``torchvision.transforms.ElasticTransform.get_params`` from alpha and
    sigma). The displacements are added to an identity grid and the resulting
    grid is used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. It must be a Tensor even
            when ``img`` is a PIL Image. ``ElasticTransform.get_params`` produces
            it with shape ``[1, H, W, 2]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still
            acceptable, but emit a warning.
        fill (sequence of floats, optional): Pixel fill value, forwarded unchanged to the
            tensor implementation. Default is ``None``.

    Returns:
        PIL Image or Tensor: Transformed image, of the same kind as ``img``. A PIL
        input is converted to a tensor for the computation and converted back
        using its original mode.

    Raises:
        TypeError: If ``displacement`` is not a Tensor, or if ``img`` is neither a
            Tensor nor a PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(elastic_transform)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("displacement should be a Tensor")

    # The actual warp only exists for tensors, so PIL inputs take a round trip
    # through a tensor and are converted back at the end.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    output = F_t.elastic_transform(
        t_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output

torchvision/transforms/functional_tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,3 +968,23 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
968968

969969
img[..., i : i + h, j : j + w] = v
970970
return img
971+
972+
973+
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Resample ``img`` on an identity grid shifted by ``displacement``.

    Args:
        img (Tensor): Image tensor; its last two dimensions are used as the grid size.
        displacement (Tensor): Offsets added to the identity grid. It is moved to
            ``img``'s device before use.
        interpolation (str): Interpolation mode passed on to the grid transform.
        fill (list of float, optional): Fill value passed on to the grid transform.

    Returns:
        Tensor: Result of applying the displaced grid to ``img``.

    Raises:
        TypeError: If ``img`` is not a Tensor.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    # One coordinate axis per spatial dimension, running from (1 - s) / s to (s - 1) / s.
    axes: List[Tensor] = []
    for extent in img.shape[-2:]:
        axes.append(torch.linspace((1 - extent) / extent, (extent - 1) / extent, extent))

    rows, cols = torch.meshgrid(axes, indexing="ij")
    # Last dimension is ordered (x, y); the leading unsqueeze adds the batch axis.
    base_grid = torch.stack([cols, rows], -1).unsqueeze(0)  # 1 x H x W x 2

    offsets = displacement.to(img.device)
    warped_grid = base_grid.to(img.device) + offsets
    return _apply_grid_transform(img, warped_grid, interpolation, fill)

torchvision/transforms/transforms.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"RandomAdjustSharpness",
5454
"RandomAutocontrast",
5555
"RandomEqualize",
56+
"ElasticTransform",
5657
]
5758

5859

@@ -2049,3 +2050,117 @@ def forward(self, img):
20492050

20502051
def __repr__(self) -> str:
20512052
return f"{self.__class__.__name__}(p={self.p})"
2053+
2054+
2055+
class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
            A single float is used for both entries; a sequence must hold exactly two
            floats. Other types (including ``int`` and ``bool``) raise ``TypeError``.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
            Validated with the same rules as ``alpha``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
            NOTE(review): the constructor currently rejects anything that is not an
            ``int`` or ``float``, so sequences are not accepted here.

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        self.alpha = self._setup_pair(alpha, "alpha")
        self.sigma = self._setup_pair(sigma, "sigma")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        if not isinstance(fill, (int, float)):
            raise TypeError(f"fill should be int or float. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def _setup_pair(value, name):
        """Validate ``value`` as a float or a two-float sequence and return it as a pair.

        A single float is expanded to ``[value, value]``; a valid sequence is returned
        unchanged. ``name`` is only used to build the error messages.
        """
        if not isinstance(value, (float, Sequence)):
            raise TypeError(f"{name} should be float or a sequence of floats. Got {type(value)}")
        if isinstance(value, float):
            return [float(value), float(value)]
        if len(value) != 2:
            raise ValueError(f"If {name} is a sequence its length should be 2. Got {len(value)}")
        for element in value:
            if not isinstance(element, float):
                raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}")
        return value

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        """Sample a random displacement field.

        Args:
            alpha (list of float): Two magnitudes, one per displacement channel.
            sigma (list of float): Two smoothness values. A channel is only blurred
                when its own sigma is positive.
            size (list of int): Spatial size of the field; ``forward`` passes ``[H, W]``.

        Returns:
            Tensor: Displacement field of shape ``1 x H x W x 2``.
        """
        # Uniform noise in [-1, 1), optionally smoothed, then scaled by alpha / size.
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            # NOTE(review): the kernel size comes from sigma[0] only, while the full
            # two-element sigma is handed to gaussian_blur -- confirm this is intended.
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        # get_params expects the size as [H, W], hence the reversal.
        size = F.get_image_size(tensor)[::-1]
        displacement = self.get_params(self.alpha, self.sigma, size)
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"alpha={self.alpha}"
            f", sigma={self.sigma}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill})"
        )

0 commit comments

Comments
 (0)