
[proto] Speed up adjust color ops #6784


Merged: 18 commits, Oct 21, 2022
10 changes: 6 additions & 4 deletions test/test_prototype_transforms_consistency.py
@@ -254,9 +254,10 @@ def __init__(
legacy_transforms.RandomAdjustSharpness,
[
ArgsKwargs(p=0, sharpness_factor=0.5),
ArgsKwargs(p=1, sharpness_factor=0.3),
ArgsKwargs(p=1, sharpness_factor=0.2),
ArgsKwargs(p=1, sharpness_factor=0.99),
],
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
),
ConsistencyConfig(
prototype_transforms.RandomGrayscale,
@@ -306,8 +307,9 @@ def __init__(
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
*[
ConsistencyConfig(
@@ -753,7 +755,7 @@ def test_randaug(self, inpt, interpolation, mocker):
expected_output = t_ref(inpt)
output = t(inpt)

assert_equal(expected_output, output)
assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"inpt",
@@ -801,7 +803,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
expected_output = t_ref(inpt)
output = t(inpt)

assert_equal(expected_output, output)
assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"inpt",
69 changes: 57 additions & 12 deletions torchvision/prototype/transforms/functional/_color.py
@@ -2,9 +2,29 @@
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

from ._meta import get_dimensions_image_tensor
from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = 1.0 if fp else 255.0
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
Contributor comment:

I like this! Supersedes the work at #6765

return output if fp else output.to(image1.dtype)
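For context, a minimal illustrative sketch (not part of the PR) of an equivalent unfused formulation, in the spirit of the previous _FT._blend, that the fused mul/add_/clamp_ chain above avoids; the helper name _blend_reference and the test tensors are hypothetical:

import torch

def _blend_reference(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
    # Unfused reference: materializes two scaled intermediates before clamping.
    bound = 1.0 if image1.is_floating_point() else 255.0
    out = ratio * image1 + (1.0 - ratio) * image2
    return out.clamp(0, bound).to(image1.dtype)

img_a = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
img_b = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
blended = _blend_reference(img_a, img_b, 0.4)  # should match what the fused kernel produces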


def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

_FT._assert_channels(image, [1, 3])

fp = image.is_floating_point()
bound = 1.0 if fp else 255.0
output = image.mul(brightness_factor).clamp_(0, bound)
@MiChatz commented on Feb 17, 2023:

Issue:
This clamp_ call forces the output values into [0, 1]. In my case that is not what I would expect, since my images are normalized to values around 0 (e.g. -0.2 to 0.8), so multiplying by a factor of 1 does not return an identical image.

Suggestion:
output = image.mul(brightness_factor).clamp_(image.min(), image.max())

Collaborator (author) reply:

@MiChatz thanks for the feedback. There is a (yet unwritten) assumption for color transformations on float images that the image range is [0, 1].

return output if fp else output.to(image.dtype)
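To make the concern above concrete, a tiny illustrative snippet (not from the PR or its tests) showing why brightness_factor=1.0 is not an identity once a float image is normalized outside [0, 1]:

import torch

# Hypothetical float image normalized around 0 rather than into [0, 1].
image = torch.tensor([[-0.2, 0.3, 0.8]])
out = image.mul(1.0).clamp_(0, 1.0)  # mirrors the kernel's clamp for floating-point inputs
print(out)  # tensor([[0.0000, 0.3000, 0.8000]]) -- the -0.2 entry is clipped to 0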


adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness


@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)


adjust_saturation_image_tensor = _FT.adjust_saturation
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

c = get_num_channels_image_tensor(image)
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

if c == 1: # Match PIL behaviour
return image

return _blend(image, _rgb_to_gray(image), saturation_factor)
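As a quick sanity check of the behaviour, an illustrative sketch (not part of the PR's test suite; it assumes adjust_saturation_image_tensor and _rgb_to_gray from this diff are in scope): saturation_factor=1.0 leaves a float image unchanged, while 0.0 collapses every channel onto the grayscale version.

import torch
from torch.testing import assert_close

img = torch.rand(3, 8, 8)  # float image in [0, 1]

# factor 1.0: _blend(img, gray, 1.0) reduces to img itself
assert_close(adjust_saturation_image_tensor(img, 1.0), img)

# factor 0.0: every channel becomes the grayscale of the input, broadcast over dim -3
assert_close(adjust_saturation_image_tensor(img, 0.0), _rgb_to_gray(img).expand_as(img))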


adjust_saturation_image_pil = _FP.adjust_saturation


@@ -38,7 +71,19 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)


adjust_contrast_image_tensor = _FT.adjust_contrast
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

c = get_num_channels_image_tensor(image)
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
grayscale_image = _rgb_to_gray(image) if c == 3 else image
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
Comment on lines +82 to +83
Collaborator comment:

Suggested change:
- grayscale_image = _rgb_to_gray(image) if c == 3 else image
- mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+ grayscale_image = _rgb_to_gray(image.to(dtype)) if c == 3 else image.to(dtype)
+ mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)

This saves one conversion in _rgb_to_gray when the input is uint8: _rgb_to_gray would convert the output of its floating-point computation back to uint8, only for the result to be converted to floating point again before the torch.mean call.

@vfdev-5 (Collaborator, author) replied on Oct 20, 2022:

@pmeier actually, doing so changes the output of mean for uint8 images.
Since the original code applied the uint8 cast at the end of _rgb_to_gray, all fractional values were discarded there. That is not the case if the image is cast to float before _rgb_to_gray.
Here is an example of the difference for mean:

tensor([[[125.9521]]])   # original implementation
# vs
tensor([[[126.4607]]])  # cast to float before `_rgb_to_gray`.

So the consistency tests then report, for example:

Mismatched elements: 256 / 2772 (9.2%)
Greatest absolute difference: 3 at index (1, 2, 4, 21) (up to 1e-05 allowed)
Greatest relative difference: 0.25 at index (2, 0, 6, 21) (up to 1e-05 allowed)

and this is a real failure, IMO.

Collaborator reply:
I agree that the behavior changes, but IMO repeatedly converting to uint8 during the computation, and thereby discarding intermediate values, sounds more like a missed opportunity in the original kernel than a bug introduced now. Thus, I would consider this more of a "bug fix" than a BC-breaking change. That said, it is not a strong opinion; not going to block over this.

return _blend(image, mean, contrast_factor)
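The rounding effect @vfdev-5 describes above can be reproduced with a small standalone sketch (illustrative tensors, not the ones from the consistency tests): casting back to uint8 inside the grayscale conversion truncates the fractional parts before the mean is taken.

import torch

torch.manual_seed(0)
img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
r, g, b = img.unbind(dim=-3)

gray_float = 0.2989 * r + 0.587 * g + 0.114 * b   # keep the floating-point values
gray_uint8 = gray_float.to(torch.uint8)            # old behaviour: cast back to uint8 first

# The two means differ slightly because the uint8 cast drops the fractional parts.
print(gray_uint8.float().mean(), gray_float.mean())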


adjust_contrast_image_pil = _FP.adjust_contrast


@@ -74,7 +119,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
else:
needs_unsquash = False

output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

if needs_unsquash:
output = output.reshape(shape)
@@ -183,13 +228,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return autocontrast_image_pil(inpt)


def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
# input img shape should be [N, H, W]
shape = img.shape
def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
# input image shape should be [N, H, W]
shape = image.shape
# Compute image histogram:
flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W]
hist = flat_img.new_zeros(shape[0], 256)
hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
flat_image = image.flatten(start_dim=1).to(torch.long) # -> [N, H * W]
hist = flat_image.new_zeros(shape[0], 256)
hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))

# Compute image cdf
chist = hist.cumsum_(dim=1)
@@ -213,7 +258,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
lut = torch.cat([zeros, lut[:, :-1]], dim=1)

return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))
return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))
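Aside: the per-image histogram above is accumulated with scatter_add_ over a ones tensor. A minimal standalone sketch of that trick (hypothetical shapes, not taken from this file):

import torch

# Two flattened "images" of 100 pixels each, as long indices into 256 histogram bins.
flat = torch.randint(0, 256, (2, 100), dtype=torch.long)
hist = flat.new_zeros(2, 256)
hist.scatter_add_(dim=1, index=flat, src=flat.new_ones(1).expand_as(flat))

assert hist.sum(dim=1).tolist() == [100, 100]  # every pixel lands in exactly one bin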


def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
6 changes: 5 additions & 1 deletion torchvision/prototype/transforms/functional/_meta.py
@@ -184,7 +184,11 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
return grayscale.repeat(repeats)


_rgb_to_gray = _FT.rgb_to_grayscale
def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
r, g, b = image.unbind(dim=-3)
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
return l_img
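For reference, the weights are essentially the ITU-R BT.601 luma coefficients, and the kernel keeps a channel dimension of size 1. A minimal usage sketch (assuming _rgb_to_gray as defined above is in scope; not part of the PR):

import torch

rgb = torch.rand(3, 4, 4)          # float RGB image, values in [0, 1]
gray = _rgb_to_gray(rgb)           # 0.2989 * R + 0.587 * G + 0.114 * B
assert gray.shape == (1, 4, 4)

batched = torch.rand(2, 3, 4, 4)   # unbind/unsqueeze on dim=-3 also handles batches
assert _rgb_to_gray(batched).shape == (2, 1, 4, 4)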


def convert_color_space_image_tensor(
7 changes: 1 addition & 6 deletions torchvision/transforms/functional_tensor.py
@@ -816,12 +816,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
