Skip to content

adjust_hue now supports inputs of type Tensor #2566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into master from adjust_hue_tensor
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
653d2fd
adjust_hue now supports inputs of type Tensor
CristianManta Aug 9, 2020
ff93f78
Added comparison between original adjust_hue and its Tensor and torch…
CristianManta Aug 11, 2020
b565af5
Added a few type checkings related to adjust_hue in functional_tensor…
CristianManta Aug 11, 2020
535d1de
Changed implementation of _rgb2hsv and removed useless type declarati…
CristianManta Aug 12, 2020
5f9ba27
Handled the range of hue_factor in the assertions and temporarily inc…
CristianManta Aug 13, 2020
3f6c5a5
Fixed some lint issues with CircleCI and added type hints in function…
CristianManta Aug 13, 2020
270e02e
Corrected type hint mistakes.
CristianManta Aug 13, 2020
3b072db
Followed PR review recommendations and added test for class interface…
CristianManta Aug 14, 2020
bdae1cc
Refactored test_functional_tensor.py to match vfdev-5's d016cab branc…
CristianManta Aug 15, 2020
70cff26
Removed test_adjustments from test_transforms_tensor.py and moved the…
CristianManta Aug 15, 2020
bb7ec8c
Added cuda test cases for test_adjustments and tried to fix conflict.
CristianManta Aug 21, 2020
aead63d
Merge branch 'master' into adjust_hue_tensor
CristianManta Aug 21, 2020
e28e558
[WIP] Merge branch 'master' of https://github.com/pytorch/vision into…
vfdev-5 Sep 1, 2020
d4dd848
Updated tests
vfdev-5 Sep 1, 2020
8551011
Fixes incompatible devices
vfdev-5 Sep 1, 2020
71185bd
Increased tol for cuda tests
vfdev-5 Sep 2, 2020
3f23938
Merge branch 'master' of https://github.com/pytorch/vision into adjus…
vfdev-5 Sep 2, 2020
38f33e7
Merge branch 'master' of https://github.com/pytorch/vision into adjus…
vfdev-5 Sep 2, 2020
e8b5f28
Merge branch 'master' of github.com:pytorch/vision into cm/adjust_hue…
vfdev-5 Sep 2, 2020
c58d151
Fixes potential issue with inplace op
vfdev-5 Sep 2, 2020
a143835
Reverted fmod -> %
vfdev-5 Sep 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,9 @@ def test_pad(self):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
script_fn = torch.jit.script(fn)

torch.manual_seed(15)

tensor, pil_img = self._create_data(26, 34, device=self.device)

for dt in [None, torch.float32, torch.float64]:
Expand All @@ -230,7 +228,6 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
tensor = F.convert_image_dtype(tensor, dt)

for config in configs:

adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, **config)
Expand All @@ -245,9 +242,12 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):

# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility between convert_image_dtype and PIL results
tol = 2.0 + 1e-10
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)

atol = 1e-6
if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
atol = 1.0
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)

def test_adjust_brightness(self):
self._test_adjust_fn(
Expand All @@ -273,6 +273,16 @@ def test_adjust_saturation(self):
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
)

def test_adjust_hue(self):
self._test_adjust_fn(
F.adjust_hue,
F_pil.adjust_hue,
F_t.adjust_hue,
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
tol=0.1,
agg_method="mean"
)

def test_adjust_gamma(self):
self._test_adjust_fn(
F.adjust_gamma,
Expand Down
18 changes: 15 additions & 3 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,36 @@ def test_random_vertical_flip(self):
def test_color_jitter(self):

tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34]:
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

for f in [0.2, 0.5, 1.0, 1.5]:
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

for f in [0.5, 0.75, 1.0, 1.25]:
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)

# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)

def test_pad(self):

# Test functional.pad (PIL and Tensor) with padding as single int
Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,20 +736,20 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
.. _Hue: https://en.wikipedia.org/wiki/Hue

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.

Returns:
PIL Image: Hue adjusted image.
PIL Image or Tensor: Hue adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return F_t.adjust_hue(img, hue_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
Expand Down
11 changes: 6 additions & 5 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
return _blend(img, mean, contrast_factor)


def adjust_hue(img, hue_factor):
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
Expand Down Expand Up @@ -185,17 +185,16 @@ def adjust_hue(img, hue_factor):
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor image. Got {}'.format(type(img)))

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0

img = _rgb2hsv(img)
h, s, v = img.unbind(0)
h += hue_factor
h = h % 1.0
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v))
img_hue_adj = _hsv2rgb(img)

Expand Down Expand Up @@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
def _rgb2hsv(img):
r, g, b = img.unbind(0)

# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=0).values
minc = torch.min(img, dim=0).values

Expand Down