Skip to content

Normalize, LinearTransformation are scriptable #2645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 23, 2020
Merged
20 changes: 20 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or
random transformations applied on the batch of Tensor Images identically transform all the images of the batch.


Scriptable transforms
^^^^^^^^^^^^^^^^^^^^^

In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.

.. code:: python

transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and does not require
`lambda` functions or ``PIL.Image``.

For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.


.. autoclass:: Compose

Transforms on PIL Image
Expand Down
57 changes: 57 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,63 @@ def test_to_grayscale(self):
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

def test_normalize(self):
tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

tensor = tensor.to(dtype=torch.float32) / 255.0
# test for class interface
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
scripted_fn = torch.jit.script(fn)

self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

def test_linear_transformation(self):
c, h, w = 3, 24, 32

tensor, _ = self._create_data(h, w, channels=c, device=self.device)

matrix = torch.rand(c * h * w, c * h * w, device=self.device)
mean_vector = torch.rand(c * h * w, device=self.device)

fn = T.LinearTransformation(matrix, mean_vector)
scripted_fn = torch.jit.script(fn)

self._test_transform_vs_scripted(fn, scripted_fn, tensor)

batch_tensors = torch.rand(4, c, h, w, device=self.device)
# We skip some tests from _test_transform_vs_scripted_on_batch as
# results for scripted and non-scripted transformations are not exactly the same
torch.manual_seed(12)
transformed_batch = fn(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))

def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0

transforms = T.Compose([
T.CenterCrop(10),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
s_transforms = torch.nn.Sequential(*transforms.transforms)

scripted_fn = torch.jit.script(s_transforms)
torch.manual_seed(12)
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))

t = T.Compose([
lambda x: x,
])
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std, inplace=False):
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
"""Normalize a tensor image with mean and standard deviation.

.. note::
Expand All @@ -292,19 +292,19 @@ def normalize(tensor, mean, std, inplace=False):
See :class:`~torchvision.transforms.Normalize` for more details.

Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.

Returns:
Tensor: Normalized Tensor image.
"""
if not torch.is_tensor(tensor):
raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))

if tensor.ndimension() != 3:
raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = '
if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size()))

if not inplace:
Expand All @@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False):
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
if mean.ndim == 1:
mean = mean[:, None, None]
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std[:, None, None]
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor

Expand Down
67 changes: 47 additions & 20 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional
from typing import Tuple, List, Optional, Any

import torch
from PIL import Image
Expand Down Expand Up @@ -33,7 +33,7 @@
}


class Compose(object):
class Compose:
"""Composes several transforms together.

Args:
Expand All @@ -44,6 +44,19 @@ class Compose(object):
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])

.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.

>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)

Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.

"""

def __init__(self, transforms):
Expand All @@ -63,7 +76,7 @@ def __repr__(self):
return format_string


class ToTensor(object):
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

Converts a PIL Image or numpy.ndarray (H x W x C) in the range
Expand Down Expand Up @@ -94,7 +107,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class PILToTensor(object):
class PILToTensor:
"""Convert a ``PIL Image`` to a tensor of the same type.

Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
Expand All @@ -114,7 +127,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class ConvertImageDtype(object):
class ConvertImageDtype:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Args:
Expand All @@ -139,7 +152,7 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)


class ToPILImage(object):
class ToPILImage:
"""Convert a tensor or an ndarray to PIL Image.

Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
Expand Down Expand Up @@ -178,7 +191,7 @@ def __repr__(self):
return format_string


class Normalize(object):
class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
Expand All @@ -196,11 +209,12 @@ class Normalize(object):
"""

def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace

def __call__(self, tensor):
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Expand Down Expand Up @@ -358,15 +372,16 @@ def __repr__(self):
format(self.padding, self.fill, self.padding_mode)


class Lambda(object):
class Lambda:
"""Apply a user-defined lambda as a transform.

Args:
lambd (function): Lambda/function to be used for transform.
"""

def __init__(self, lambd):
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
if not callable(lambd):
raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
self.lambd = lambd

def __call__(self, img):
Expand All @@ -376,7 +391,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class RandomTransforms(object):
class RandomTransforms:
"""Base class for a list of transformations with randomness

Args:
Expand Down Expand Up @@ -408,7 +423,7 @@ class RandomApply(RandomTransforms):
"""

def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__(transforms)
super().__init__(transforms)
self.p = p

def __call__(self, img):
Expand Down Expand Up @@ -897,7 +912,7 @@ def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)


class LinearTransformation(object):
class LinearTransformation(torch.nn.Module):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
Expand All @@ -916,6 +931,7 @@ class LinearTransformation(object):
"""

def __init__(self, transformation_matrix, mean_vector):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
Expand All @@ -925,24 +941,35 @@ def __init__(self, transformation_matrix, mean_vector):
" as any one of the dimensions of the transformation_matrix [{}]"
.format(tuple(transformation_matrix.size())))

if transformation_matrix.device != mean_vector.device:
raise ValueError("Input tensors should be on the same device. Got {} and {}"
.format(transformation_matrix.device, mean_vector.device))

self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector

def __call__(self, tensor):
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

Returns:
Tensor: Transformed image.
"""
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
raise ValueError("tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1) - self.mean_vector
shape = tensor.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError("Input tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
"{}".format(self.transformation_matrix.shape[0]))

if tensor.device.type != self.mean_vector.device.type:
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device))

flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size())
tensor = transformed_tensor.view(shape)
return tensor

def __repr__(self):
Expand Down