diff --git a/test/test_transforms.py b/test/test_transforms.py
index 54aae796301..80eb3f2edc4 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -328,14 +328,15 @@ def test_pad_raises_with_invalid_pad_sequence_len(self):
             transforms.Pad((1, 2, 3, 4, 5))
 
     def test_lambda(self):
-        trans = transforms.Lambda(lambda x: x.add(10))
+        trans = transforms.Lambda(lambda x, a: (x.add(10), a))
         x = torch.randn(10)
-        y = trans(x)
+        a = []
+        y, a = trans(x, a)
         assert (y.equal(torch.add(x, 10)))
 
-        trans = transforms.Lambda(lambda x: x.add_(10))
+        trans = transforms.Lambda(lambda x, a: (x.add_(10), a))
         x = torch.randn(10)
-        y = trans(x)
+        y, a = trans(x, a)
         assert (y.equal(x))
 
         # Checking if Lambda can be printed as string
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 964504eb9dc..7b4be27e85e 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -40,7 +40,30 @@
 }
 
 
-class Compose(object):
+class AbstractTransform(object):
+    """Base class for all transforms.
+
+    Its role is to simulate parametric polymorphism so that
+    if the transform is called with an image only it returns the image only,
+    and if it is called with an image and keypoints it returns both.
+
+    This is done in order to add keypoint transformation without breaking previous interface.
+    """
+
+    def run(self, img, keypoints):
+        raise NotImplementedError()  # subclasses return an (img, keypoints) pair
+
+    def __call__(self, img, keypoints=None):
+        kp = []  # placeholder so run() always receives an iterable
+        if keypoints is not None:
+            kp = keypoints  # NOTE: same list object; transforms edit the points in place
+        img, kp = self.run(img, kp)
+        if keypoints is not None:
+            return img, kp
+        return img
+
+
+class Compose(AbstractTransform):
     """Composes several transforms together.
 
     Args:
@@ -56,10 +79,10 @@ class Compose(object):
     def __init__(self, transforms):
         self.transforms = transforms
 
-    def __call__(self, img):
+    def run(self, img, keypoints):
         for t in self.transforms:
-            img = t(img)
-        return img
+            img, keypoints = t(img, keypoints)  # assumes every t returns an (img, keypoints) pair
+        return img, keypoints
 
     def __repr__(self):
         format_string = self.__class__.__name__ + '('
@@ -70,7 +93,7 @@ def __repr__(self):
         return format_string
 
 
-class ToTensor(object):
+class ToTensor(AbstractTransform):
     """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
 
     Converts a PIL Image or numpy.ndarray (H x W x C) in the range
@@ -81,7 +104,10 @@ class ToTensor(object):
     In the other cases, tensors are returned without scaling.
     """
 
-    def __call__(self, pic):
+    def __init__(self):
+        super(ToTensor, self).__init__()
+
+    def run(self, pic, keypoints):
         """
         Args:
             pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
@@ -89,13 +115,13 @@ def __call__(self, pic):
         Returns:
             Tensor: Converted image.
         """
-        return F.to_tensor(pic)
+        return F.to_tensor(pic), keypoints  # keypoints are handed back untouched
 
     def __repr__(self):
         return self.__class__.__name__ + '()'
 
 
-class ToPILImage(object):
+class ToPILImage(AbstractTransform):
     """Convert a tensor or an ndarray to PIL Image.
 
     Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
@@ -115,7 +141,7 @@ class ToPILImage(object):
     def __init__(self, mode=None):
         self.mode = mode
 
-    def __call__(self, pic):
+    def run(self, pic, keypoints):
         """
         Args:
             pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
@@ -124,7 +150,7 @@ def __call__(self, pic):
             PIL Image: Image converted to PIL Image.
 
         """
-        return F.to_pil_image(pic, self.mode)
+        return F.to_pil_image(pic, self.mode), keypoints  # keypoints are handed back untouched
 
     def __repr__(self):
         format_string = self.__class__.__name__ + '('
@@ -134,7 +160,7 @@ def __repr__(self):
         return format_string
 
 
-class Normalize(object):
+class Normalize(AbstractTransform):
     """Normalize a tensor image with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e. @@ -155,7 +181,7 @@ def __init__(self, mean, std, inplace=False): self.std = std self.inplace = inplace - def __call__(self, tensor): + def run(self, tensor, keypoints): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. @@ -163,13 +189,13 @@ def __call__(self, tensor): Returns: Tensor: Normalized Tensor image. """ - return F.normalize(tensor, self.mean, self.std, self.inplace) + return F.normalize(tensor, self.mean, self.std, self.inplace), keypoints def __repr__(self): return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) -class Resize(object): +class Resize(AbstractTransform): """Resize the input PIL Image to the given size. Args: @@ -187,7 +213,7 @@ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be scaled. @@ -195,7 +221,18 @@ def __call__(self, img): Returns: PIL Image: Rescaled image. """ - return F.resize(img, self.size, self.interpolation) + ratioX, ratioY = 1, 1 + if isinstance(self.size, numbers.Number): + smallSize = min(img.width, img.height) + ratioX = float(self.size) / smallSize + ratioY = float(self.size) / smallSize + else: + ratioX = float(self.size[0]) / img.width + ratioY = float(self.size[1]) / img.height + for pointPair in keypoints: + pointPair[0] *= ratioX + pointPair[1] *= ratioY + return F.resize(img, self.size, self.interpolation), keypoints def __repr__(self): interpolate_str = _pil_interpolation_to_str[self.interpolation] @@ -212,7 +249,7 @@ def __init__(self, *args, **kwargs): super(Scale, self).__init__(*args, **kwargs) -class CenterCrop(object): +class CenterCrop(AbstractTransform): """Crops the given PIL Image at the center. 
Args: @@ -227,7 +264,7 @@ def __init__(self, size): else: self.size = size - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be cropped. @@ -235,13 +272,18 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ - return F.center_crop(img, self.size) + croppedX = (img.width - self.size[0]) / 2 + croppedY = (img.height - self.size[1]) / 2 + for pointPair in keypoints: + pointPair[0] -= croppedX + pointPair[1] -= croppedY + return F.center_crop(img, self.size), keypoints def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class Pad(object): +class Pad(AbstractTransform): """Pad the given PIL Image on all sides with the given "pad" value. Args: @@ -283,7 +325,7 @@ def __init__(self, padding, fill=0, padding_mode='constant'): self.fill = fill self.padding_mode = padding_mode - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be padded. @@ -291,14 +333,22 @@ def __call__(self, img): Returns: PIL Image: Padded image. """ - return F.pad(img, self.padding, self.fill, self.padding_mode) + padX, padY = 0, 0 + if isinstance(self.padding, numbers.Number): + padX, padY = self.padding, self.padding + else: + padX, padY = self.padding[0], self.padding[1] + for pointPair in keypoints: + pointPair[0] += padX + pointPair[1] += padY + return F.pad(img, self.padding, self.fill, self.padding_mode), keypoints def __repr__(self): return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ format(self.padding, self.fill, self.padding_mode) -class Lambda(object): +class Lambda(AbstractTransform): """Apply a user-defined lambda as a transform. 
 
     Args:
@@ -309,14 +359,14 @@ def __init__(self, lambd):
         assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
         self.lambd = lambd
 
-    def __call__(self, img):
-        return self.lambd(img)
+    def run(self, *args, **kwargs):  # NOTE(review): lambd now receives (img, keypoints), not img alone
+        return self.lambd(*args, **kwargs)  # lambd must itself return an (img, keypoints) pair
 
     def __repr__(self):
         return self.__class__.__name__ + '()'
 
 
-class RandomTransforms(object):
+class RandomTransforms(AbstractTransform):
     """Base class for a list of transformations with randomness
 
     Args:
@@ -327,7 +377,7 @@ def __init__(self, transforms):
         assert isinstance(transforms, (list, tuple))
         self.transforms = transforms
 
-    def __call__(self, *args, **kwargs):
+    def run(self, *args, **kwargs):
         raise NotImplementedError()
 
     def __repr__(self):
@@ -351,12 +401,12 @@ def __init__(self, transforms, p=0.5):
         super(RandomApply, self).__init__(transforms)
         self.p = p
 
-    def __call__(self, img):
+    def run(self, img, keypoints):
         if self.p < random.random():
-            return img
+            return img, keypoints  # skipped: image and keypoints are returned as received
         for t in self.transforms:
-            img = t(img)
-        return img
+            img, keypoints = t(img, keypoints)
+        return img, keypoints
 
     def __repr__(self):
         format_string = self.__class__.__name__ + '('
@@ -371,23 +421,23 @@ def __repr__(self):
 class RandomOrder(RandomTransforms):
     """Apply a list of transformations in a random order
     """
-    def __call__(self, img):
+    def run(self, img, keypoints):
         order = list(range(len(self.transforms)))
         random.shuffle(order)
         for i in order:
-            img = self.transforms[i](img)
-        return img
+            img, keypoints = self.transforms[i](img, keypoints)
+        return img, keypoints
 
 
 class RandomChoice(RandomTransforms):
     """Apply single transformation randomly picked from a list
     """
-    def __call__(self, img):
+    def run(self, img, keypoints):
         t = random.choice(self.transforms)
-        return t(img)
+        return t(img, keypoints)  # assumes t returns an (img, keypoints) pair when given keypoints
 
 
-class RandomCrop(object):
+class RandomCrop(AbstractTransform):
     """Crop the given PIL Image at a random location.
Args: @@ -453,7 +503,7 @@ def get_params(img, output_size): j = random.randint(0, w - tw) return i, j, th, tw - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be cropped. @@ -461,25 +511,39 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ + deltaX, deltaY = 0, 0 if self.padding is not None: + if isinstance(self.padding, numbers.Number): + deltaX += self.padding + deltaY += self.padding + else: + deltaX += self.padding[0] + deltaY += self.padding[1] img = F.pad(img, self.padding, self.fill, self.padding_mode) # pad the width if needed if self.pad_if_needed and img.size[0] < self.size[1]: img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + deltaX += (self.size[1] - img.size[0]) # pad the height if needed if self.pad_if_needed and img.size[1] < self.size[0]: img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + deltaY += (self.size[0] - img.size[1]) i, j, h, w = self.get_params(img, self.size) + deltaX += i + deltaX += j - return F.crop(img, i, j, h, w) + for pointPair in keypoints: + pointPair[0] += deltaX + pointPair[1] += deltaY + return F.crop(img, i, j, h, w), keypoints def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) -class RandomHorizontalFlip(object): +class RandomHorizontalFlip(AbstractTransform): """Horizontally flip the given PIL Image randomly with a given probability. Args: @@ -489,7 +553,7 @@ class RandomHorizontalFlip(object): def __init__(self, p=0.5): self.p = p - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be flipped. @@ -498,14 +562,16 @@ def __call__(self, img): PIL Image: Randomly flipped image. 
""" if random.random() < self.p: - return F.hflip(img) - return img + for pointPair in keypoints: + pointPair[0] = img.width - pointPair[0] + return F.hflip(img), keypoints + return img, keypoints def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomVerticalFlip(object): +class RandomVerticalFlip(AbstractTransform): """Vertically flip the given PIL Image randomly with a given probability. Args: @@ -515,7 +581,7 @@ class RandomVerticalFlip(object): def __init__(self, p=0.5): self.p = p - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be flipped. @@ -524,14 +590,16 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ if random.random() < self.p: - return F.vflip(img) - return img + for pointPair in keypoints: + pointPair[1] = img.height - pointPair[1] + return F.vflip(img), keypoints + return img, keypoints def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomPerspective(object): +class RandomPerspective(AbstractTransform): """Performs Perspective transformation of the given PIL Image randomly with a given probability. Args: @@ -548,7 +616,7 @@ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): self.interpolation = interpolation self.distortion_scale = distortion_scale - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be Perspectively transformed. 
@@ -562,7 +630,19 @@ def __call__(self, img):
         if random.random() < self.p:
             width, height = img.size
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
-            return F.perspective(img, startpoints, endpoints, self.interpolation)
+            # Move each keypoint by bilinear interpolation between the four warped corners.
+            # NOTE(review): this approximates the projective warp -- confirm that is acceptable.
+            tl, tr, br, bl = endpoints
+            for pointPair in keypoints:
+                x = pointPair[0]
+                y = pointPair[1]
+                lineStartX = tl[0] + (y / img.height) * (bl[0] - tl[0])
+                lineStartY = tl[1] + (y / img.height) * (bl[1] - tl[1])
+                lineEndX = tr[0] + (y / img.height) * (br[0] - tr[0])
+                lineEndY = tr[1] + (y / img.height) * (br[1] - tr[1])
+                pointPair[0] = int(lineStartX + (x / img.width) * (lineEndX - lineStartX))
+                pointPair[1] = int(lineStartY + (x / img.width) * (lineEndY - lineStartY))
+            return F.perspective(img, startpoints, endpoints, self.interpolation), keypoints
-        return img
+        return img, keypoints  # __call__ unpacks a pair on the no-op path too
 
     @staticmethod
@@ -595,7 +675,7 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
-class RandomResizedCrop(object):
+class RandomResizedCrop(AbstractTransform):
     """Crop the given PIL Image to random size and aspect ratio.
 
     A crop of random size (default: of 0.08 to 1.0) of the original size and a random
@@ -665,7 +745,7 @@ def get_params(img, scale, ratio):
         j = (img.size[0] - w) // 2
         return i, j, h, w
 
-    def __call__(self, img):
+    def run(self, img, keypoints):
         """
         Args:
             img (PIL Image): Image to be cropped and resized.
@@ -674,7 +754,12 @@ def __call__(self, img):
             PIL Image: Randomly cropped and resized image.
         """
         i, j, h, w = self.get_params(img, self.scale, self.ratio)
-        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+        for pointPair in keypoints:  # translate into the crop, then rescale; in place
+            pointPair[0] -= j
+            pointPair[1] -= i
+            pointPair[0] *= float(self.size[1]) / w  # size is (h, w): width is index 1
+            pointPair[1] *= float(self.size[0]) / h
+        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), keypoints
 
     def __repr__(self):
         interpolate_str = _pil_interpolation_to_str[self.interpolation]
@@ -695,7 +780,7 @@ def __init__(self, *args, **kwargs):
         super(RandomSizedCrop, self).__init__(*args, **kwargs)
 
 
-class FiveCrop(object):
+class FiveCrop(AbstractTransform):
     """Crop the given PIL Image into four corners and the central crop
 
     .. Note::
@@ -727,14 +812,14 @@ def __init__(self, size):
             assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
             self.size = size
 
-    def __call__(self, img):
-        return F.five_crop(img, self.size)
+    def run(self, img, keypoints):
+        return F.five_crop(img, self.size), keypoints  # NOTE: keypoints not remapped per crop
 
     def __repr__(self):
         return self.__class__.__name__ + '(size={0})'.format(self.size)
 
 
-class TenCrop(object):
+class TenCrop(AbstractTransform):
     """Crop the given PIL Image into four corners and the central crop plus the flipped version of
     these (horizontal flipping is used by default)
 
@@ -770,14 +855,14 @@ def __init__(self, size, vertical_flip=False):
             assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
             self.size = size
 
-    def __call__(self, img):
-        return F.ten_crop(img, self.size, self.vertical_flip)
+    def run(self, img, keypoints):
+        return F.ten_crop(img, self.size, self.vertical_flip), keypoints  # NOTE: keypoints not remapped per crop
 
     def __repr__(self):
         return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
 
 
-class LinearTransformation(object):
+class LinearTransformation(AbstractTransform):
     """Transform a tensor image with a square transformation matrix and a mean_vector computed
     offline.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and @@ -808,7 +893,7 @@ def __init__(self, transformation_matrix, mean_vector): self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def __call__(self, tensor): + def run(self, tensor, keypoints): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be whitened. @@ -823,7 +908,7 @@ def __call__(self, tensor): flat_tensor = tensor.view(1, -1) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) tensor = transformed_tensor.view(tensor.size()) - return tensor + return tensor, keypoints def __repr__(self): format_string = self.__class__.__name__ + '(transformation_matrix=' @@ -832,7 +917,7 @@ def __repr__(self): return format_string -class ColorJitter(object): +class ColorJitter(AbstractTransform): """Randomly change the brightness, contrast and saturation of an image. Args: @@ -889,26 +974,26 @@ def get_params(brightness, contrast, saturation, hue): if brightness is not None: brightness_factor = random.uniform(brightness[0], brightness[1]) - transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + transforms.append(Lambda(lambda img, anno: (F.adjust_brightness(img, brightness_factor), anno))) if contrast is not None: contrast_factor = random.uniform(contrast[0], contrast[1]) - transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + transforms.append(Lambda(lambda img, anno: (F.adjust_contrast(img, contrast_factor), anno))) if saturation is not None: saturation_factor = random.uniform(saturation[0], saturation[1]) - transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + transforms.append(Lambda(lambda img, anno: (F.adjust_saturation(img, saturation_factor), anno))) if hue is not None: hue_factor = random.uniform(hue[0], hue[1]) - transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) + transforms.append(Lambda(lambda img, 
anno: (F.adjust_hue(img, hue_factor), anno))) random.shuffle(transforms) transform = Compose(transforms) return transform - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Input image. @@ -918,7 +1003,7 @@ def __call__(self, img): """ transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) - return transform(img) + return transform(img), keypoints def __repr__(self): format_string = self.__class__.__name__ + '(' @@ -929,7 +1014,7 @@ def __repr__(self): return format_string -class RandomRotation(object): +class RandomRotation(AbstractTransform): """Rotate the image by angle. Args: @@ -976,7 +1061,7 @@ def get_params(degrees): return angle - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be rotated. @@ -985,9 +1070,19 @@ def __call__(self, img): PIL Image: Rotated image. """ + if self.center is None: + self.center = [img.width / 2, img.height / 2] angle = self.get_params(self.degrees) - - return F.rotate(img, angle, self.resample, self.expand, self.center) + inrad = -math.radians(angle) + for pointPair in keypoints: + x, y = pointPair + x -= self.center[0] + y -= self.center[1] + pointPair[0] = math.cos(inrad) * x - math.sin(inrad) * y + pointPair[1] = math.sin(inrad) * x + math.cos(inrad) * y + pointPair[0] += self.center[0] + pointPair[1] += self.center[1] + return F.rotate(img, angle, self.resample, self.expand, self.center), keypoints def __repr__(self): format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) @@ -999,7 +1094,7 @@ def __repr__(self): return format_string -class RandomAffine(object): +class RandomAffine(AbstractTransform): """Random affine transformation of the image keeping center invariant Args: @@ -1106,7 +1201,7 @@ def get_params(degrees, translate, scale_ranges, shears, img_size): return angle, translations, scale, shear - def __call__(self, img): + def run(self, img): """ img (PIL Image): Image to be 
transformed. @@ -1134,7 +1229,7 @@ def __repr__(self): return s.format(name=self.__class__.__name__, **d) -class Grayscale(object): +class Grayscale(AbstractTransform): """Convert image to grayscale. Args: @@ -1150,7 +1245,7 @@ class Grayscale(object): def __init__(self, num_output_channels=1): self.num_output_channels = num_output_channels - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be converted to grayscale. @@ -1158,13 +1253,13 @@ def __call__(self, img): Returns: PIL Image: Randomly grayscaled image. """ - return F.to_grayscale(img, num_output_channels=self.num_output_channels) + return F.to_grayscale(img, num_output_channels=self.num_output_channels), keypoints def __repr__(self): return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) -class RandomGrayscale(object): +class RandomGrayscale(AbstractTransform): """Randomly convert image to grayscale with a probability of p (default 0.1). Args: @@ -1181,7 +1276,7 @@ class RandomGrayscale(object): def __init__(self, p=0.1): self.p = p - def __call__(self, img): + def run(self, img, keypoints): """ Args: img (PIL Image): Image to be converted to grayscale. @@ -1191,14 +1286,14 @@ def __call__(self, img): """ num_output_channels = 1 if img.mode == 'L' else 3 if random.random() < self.p: - return F.to_grayscale(img, num_output_channels=num_output_channels) - return img + return F.to_grayscale(img, num_output_channels=num_output_channels), keypoints + return img, keypoints def __repr__(self): return self.__class__.__name__ + '(p={0})'.format(self.p) -class RandomErasing(object): +class RandomErasing(AbstractTransform): """ Randomly selects a rectangle region in an image and erases its pixels. 'Random Erasing Data Augmentation' by Zhong et al. 
         See https://arxiv.org/pdf/1708.04896.pdf
@@ -1274,7 +1369,7 @@ def get_params(img, scale, ratio, value=0):
         # Return original image
         return 0, 0, img_h, img_w, img
 
-    def __call__(self, img):
+    def run(self, img, keypoints):
         """
         Args:
             img (Tensor): Tensor image of size (C, H, W) to be erased.
@@ -1284,5 +1379,5 @@ def __call__(self, img):
         """
         if random.uniform(0, 1) < self.p:
             x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
-            return F.erase(img, x, y, h, w, v, self.inplace)
-        return img
+            return F.erase(img, x, y, h, w, v, self.inplace), keypoints  # keypoints handed back untouched
+        return img, keypoints  # not erased: both values are returned as received