diff --git a/README.md b/README.md index 495310f4cbd..e20a7d0b402 100644 --- a/README.md +++ b/README.md @@ -144,14 +144,6 @@ The data is preprocessed [as described here](https://github.com/facebook/fb.resn Transforms are common image transforms. They can be chained together using `transforms.Compose` -- `ToTensor()` - converts PIL Image to Tensor -- `Normalize(mean, std)` - normalizes the image given mean, std (for example: mean = [0.3, 1.2, 2.1]) -- `Scale(size, interpolation=Image.BILINEAR)` - Scales the smaller image edge to the given size. Interpolation modes are options from PIL -- `CenterCrop(size)` - center-crops the image to the given size -- `RandomCrop(size)` - Random crops the image to the given size. -- `RandomHorizontalFlip()` - hflip the image with probability 0.5 -- `RandomSizedCrop(size, interpolation=Image.BILINEAR)` - Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style) - ### `transforms.Compose` One can compose several transforms together. @@ -166,3 +158,45 @@ transform = transforms.Compose([ std = [ 0.229, 0.224, 0.225 ]), ]) ``` + +## Transforms on PIL.Image + +### `Scale(size, interpolation=Image.BILINEAR)` +Rescales the input PIL.Image to the given 'size'. +'size' will be the size of the smaller edge. + +For example, if height > width, then image will be +rescaled to (size * height / width, size) +- size: size of the smaller edge +- interpolation: Default: PIL.Image.BILINEAR + +### `CenterCrop(size)` - center-crops the image to the given size +Crops the given PIL.Image at the center to have a region of +the given size. size can be a tuple (target_height, target_width) +or an integer, in which case the target will be of a square shape (size, size) + +### `RandomCrop(size)` +Crops the given PIL.Image at a random location to have a region of +the given size. size can be a tuple (target_height, target_width) +or an integer, in which case the target will be of a square shape (size, size) + +### `RandomHorizontalFlip()` +Randomly horizontally flips the given PIL.Image with a probability of 0.5 + +### `RandomSizedCrop(size, interpolation=Image.BILINEAR)` +Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size +and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio + +This is popularly used to train the Inception networks +- size: size of the smaller edge +- interpolation: Default: PIL.Image.BILINEAR + +## Transforms on torch.*Tensor + +### `Normalize(mean, std)` +Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the torch.*Tensor, i.e. channel = (channel - mean) / std + +## Conversion Transforms +- `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] +- `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255] + diff --git a/test/test_transforms.py b/test/test_transforms.py new file mode 100644 index 00000000000..499c8c9ecc4 --- /dev/null +++ b/test/test_transforms.py @@ -0,0 +1,87 @@ +import torch +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import numpy as np +import unittest +import random + +class Tester(unittest.TestCase): + def test_crop(self): + height = random.randint(10, 32) * 2 + width = random.randint(10, 32) * 2 + oheight = random.randint(5, (height - 2) / 2) * 2 + owidth = random.randint(5, (width - 2) / 2) * 2 + + img = torch.ones(3, height, width) + oh1 = (height - oheight) / 2 + ow1 = (width - owidth) / 2 + imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth] + imgnarrow.fill_(0) + result = transforms.Compose([ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ])(img) + assert result.sum() == 0, "height: " + str(height) + " width: " \ + + str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + oheight += 1 + owidth += 1 + result = transforms.Compose([ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ])(img) + sum1 = result.sum() + assert sum1 > 1, "height: " + str(height) + " width: " \ + + str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + oheight += 1 + owidth += 1 + result = transforms.Compose([ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ])(img) + sum2 = result.sum() + assert sum2 > 0, "height: " + str(height) + " width: " \ + + str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + assert sum2 > sum1, "height: " + str(height) + " width: " \ + + str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + + def test_scale(self): + height = random.randint(24, 32) * 2 + width = random.randint(24, 32) * 2 + osize = random.randint(5, 12) * 2 + + img = torch.ones(3, height, width) + result = transforms.Compose([ + transforms.ToPILImage(), + transforms.Scale(osize), + transforms.ToTensor(), + ])(img) + # print img.size() + # print 'output size:', osize + # print result.size() + assert osize in result.size() + if height < width: + assert result.size(1) <= result.size(2) + elif width < height: + assert result.size(1) >= result.size(2) + + def test_random_crop(self): + height = random.randint(10, 32) * 2 + width = random.randint(10, 32) * 2 + oheight = random.randint(5, (height - 2) / 2) * 2 + owidth = random.randint(5, (width - 2) / 2) * 2 + img = torch.ones(3, height, width) + result = transforms.Compose([ + transforms.ToPILImage(), + transforms.RandomCrop((oheight, owidth)), + transforms.ToTensor(), + ])(img) + assert result.size(1) == oheight + assert result.size(2) == owidth + + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/transforms.py b/torchvision/transforms.py index cfc0b5c8755..797e2363ef5 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -1,11 +1,19 @@ +from __future__ import division import torch import math import random from PIL import Image import numpy as np - +import numbers class Compose(object): + """ Composes several transforms together. + For example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ def __init__(self, transforms): self.transforms = transforms @@ -16,6 +24,8 @@ def __call__(self, img): class ToTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ def __call__(self, pic): if isinstance(pic, np.ndarray): # handle numpy array @@ -24,24 +34,50 @@ def __call__(self, pic): # handle PIL Image img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) img = img.view(pic.size[0], pic.size[1], 3) - # put it in CHW format + # put it from WHC to CHW format # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 2).transpose(1, 2).contiguous() - return img.float() + img = img.transpose(0, 2).contiguous() + return img.float().div(255) + +class ToPILImage(object): + """ Converts a torch.*Tensor of range [0, 1] and shape C x H x W + or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C + to a PIL.Image of range [0, 255] + """ + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # handle numpy array + img = Image.fromarray(pic) + else: + npimg = pic.mul(255).byte().numpy() + npimg = np.transpose(npimg, (1,2,0)) + img = Image.fromarray(npimg) + return img class Normalize(object): + """ Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): + # TODO: make efficient for t, m, s in zip(tensor, self.mean, self.std): t.sub_(m).div_(s) return tensor class Scale(object): - "Scales the smaller edge to size" + """ Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation @@ -51,27 +87,44 @@ def __call__(self, img): if (w <= h and w == self.size) or (h <= w and h == self.size): return img if w < h: - return img.resize((w, int(round(h / w * self.size))), self.interpolation) + ow = self.size + oh = int(self.size * h / w) + return img.resize((ow, oh), self.interpolation) else: - return img.resize((int(round(w / h * self.size)), h), self.interpolation) + oh = self.size + ow = int(self.size * w / h) + return img.resize((ow, oh), self.interpolation) class CenterCrop(object): - "Crop to centered rectangle" + """Crops the given PIL.Image at the center to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ def __init__(self, size): - self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size def __call__(self, img): w, h = img.size - x1 = int(round((w - self.size) / 2)) - y1 = int(round((h - self.size) / 2)) - return img.crop((x1, y1, x1 + self.size, y1 + self.size)) + th, tw = self.size + x1 = int(round((w - tw) / 2)) + y1 = int(round((h - th) / 2)) + return img.crop((x1, y1, x1 + tw, y1 + th)) class RandomCrop(object): - "Random crop form larger image with optional zero padding" + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ def __init__(self, size, padding=0): - self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size self.padding = padding def __call__(self, img): @@ -79,16 +132,18 @@ def __call__(self, img): raise NotImplementedError() w, h = img.size - if w == self.size and h == self.size: + th, tw = self.size + if w == tw and h == th: return img - x1 = random.randint(0, w - self.size) - y1 = random.randint(0, h - self.size) - return img.crop((x1, y1, x1 + self.size, y1 + self.size)) + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + return img.crop((x1, y1, x1 + tw, y1 + th)) class RandomHorizontalFlip(object): - "Horizontal flip with 0.5 probability" + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ def __call__(self, img): if random.random() < 0.5: return img.transpose(Image.FLIP_LEFT_RIGHT) @@ -96,7 +151,12 @@ def __call__(self, img): class RandomSizedCrop(object): - "Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)" + """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size + and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio + This is popularly used to train the Inception networks + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation