Skip to content

Make ColorJitter torchscriptable #2298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,32 @@ def test_random_horizontal_flip(self):
def test_random_vertical_flip(self):
self._test_flip('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
for _ in range(20):
factor = 3 * torch.rand(1).item()
tensor, _ = self._create_data()
pil_img = T.ToPILImage()(tensor)

for func in fns:
adjusted_tensor = getattr(F, func)(tensor, factor)
adjusted_pil_img = getattr(F, func)(pil_img, factor)

adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img)
scripted_fn = torch.jit.script(getattr(F, func))
adjusted_tensor_script = scripted_fn(tensor, factor)

if not tensor.dtype.is_floating_point:
adjusted_tensor = adjusted_tensor.to(torch.float) / 255
adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255

# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max()
max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)


if __name__ == '__main__':
unittest.main()
42 changes: 18 additions & 24 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,64 +633,58 @@ def ten_crop(img, size, vertical_flip=False):
return first_five + second_five


def adjust_brightness(img, brightness_factor):
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""Adjust brightness of an Image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.

Returns:
PIL Image: Brightness adjusted image.
PIL Image or Torch Tensor: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_brightness(img, brightness_factor)

enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img, contrast_factor):
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
"""Adjust contrast of an Image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.

Returns:
PIL Image: Contrast adjusted image.
PIL Image or Torch Tensor: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_contrast(img, contrast_factor)

enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img, saturation_factor):
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
"""Adjust color saturation of an image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.

Returns:
PIL Image: Saturation adjusted image.
PIL Image or Torch Tensor: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_saturation(img, saturation_factor)

enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img, hue_factor):
Expand Down
56 changes: 56 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,59 @@ def vflip(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image.

Args:
img (PIL Image): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.

Returns:
PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img


@torch.jit.unused
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img


@torch.jit.unused
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image: Saturation adjusted image.
"""
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
9 changes: 5 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def __repr__(self):
return format_string


class ColorJitter(object):
class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image.

Args:
Expand All @@ -883,6 +883,7 @@ class ColorJitter(object):
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
Expand Down Expand Up @@ -941,13 +942,13 @@ def get_params(brightness, contrast, saturation, hue):

return transform

def __call__(self, img):
def forward(self, img):
"""
Args:
img (PIL Image): Input image.
img (PIL Image or Tensor): Input image.

Returns:
PIL Image: Color jittered image.
PIL Image or Tensor: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
Expand Down