Merged · Changes from 7 commits
test/test_prototype_transforms_consistency.py (6 additions, 4 deletions)

@@ -254,9 +254,10 @@ def __init__(
         legacy_transforms.RandomAdjustSharpness,
         [
             ArgsKwargs(p=0, sharpness_factor=0.5),
-            ArgsKwargs(p=1, sharpness_factor=0.3),
+            ArgsKwargs(p=1, sharpness_factor=0.2),
             ArgsKwargs(p=1, sharpness_factor=0.99),
         ],
+        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
     ),
     ConsistencyConfig(
         prototype_transforms.RandomGrayscale,
@@ -306,8 +307,9 @@ def __init__(
             ArgsKwargs(saturation=(0.8, 0.9)),
             ArgsKwargs(hue=0.3),
             ArgsKwargs(hue=(-0.1, 0.2)),
-            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
+            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
         ],
+        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
     ),
     *[
         ConsistencyConfig(
@@ -753,7 +755,7 @@ def test_randaug(self, inpt, interpolation, mocker):
         expected_output = t_ref(inpt)
         output = t(inpt)

-        assert_equal(expected_output, output)
+        assert_close(expected_output, output, atol=1, rtol=0.1)

     @pytest.mark.parametrize(
         "inpt",
@@ -801,7 +803,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
         expected_output = t_ref(inpt)
         output = t(inpt)

-        assert_equal(expected_output, output)
+        assert_close(expected_output, output, atol=1, rtol=0.1)

     @pytest.mark.parametrize(
         "inpt",
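Note on the test changes above: the new pure-tensor kernels can differ from the legacy ones by one intensity level on uint8 inputs due to rounding, so the exact comparison (assert_equal) is replaced by a tolerance-based one. A minimal sketch of what assert_close(..., atol=1, rtol=0.1) tolerates, with made-up values and assuming the test helper forwards to torch.testing.assert_close:

    import torch
    from torch.testing import assert_close

    # Hypothetical outputs of the legacy and the new kernel that differ
    # by at most one intensity level from rounding.
    expected = torch.tensor([10, 128, 255], dtype=torch.uint8)
    actual = torch.tensor([11, 127, 255], dtype=torch.uint8)

    # Passes: assert_close allows |actual - expected| <= atol + rtol * |expected|.
    assert_close(actual, expected, atol=1, rtol=0.1)
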
torchvision/prototype/transforms/functional/_color.py (53 additions, 5 deletions)

@@ -2,9 +2,29 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import get_dimensions_image_tensor
+from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor


+def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor:
+    ratio = float(ratio)
+    fp = img1.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = img1.mul(ratio).add_(img2, alpha=(1.0 - ratio)).clamp_(0, bound)
+    return output if fp else output.to(img1.dtype)
+
+
+def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+    if brightness_factor < 0:
+        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
+
+    _FT._assert_channels(img, [1, 3])
+
+    fp = img.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = img.mul(brightness_factor).clamp_(0, bound)
+    return output if fp else output.to(img.dtype)
+
+
-adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness
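
The new private _blend helper computes ratio * img1 + (1 - ratio) * img2, clamps to the valid range (1.0 for floating point, 255.0 otherwise), and casts back to the input dtype. A small usage sketch, assuming the helper is importable from the module shown above:

    import torch
    from torchvision.prototype.transforms.functional._color import _blend

    img = torch.full((3, 2, 2), 200, dtype=torch.uint8)
    degenerate = torch.full((3, 2, 2), 100, dtype=torch.uint8)

    # ratio=1.0 keeps the input, ratio=0.0 yields the degenerate image,
    # and intermediate ratios interpolate (here 0.5 * 200 + 0.5 * 100 = 150).
    assert torch.equal(_blend(img, degenerate, 1.0), img)
    assert torch.equal(_blend(img, degenerate, 0.0), degenerate)
    assert torch.equal(_blend(img, degenerate, 0.5), torch.full((3, 2, 2), 150, dtype=torch.uint8))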


@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)


-adjust_saturation_image_tensor = _FT.adjust_saturation
+def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) -> torch.Tensor:
+    if saturation_factor < 0:
+        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(img)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+
+    if c == 1:  # Match PIL behaviour
+        return img
+
+    return _blend(img, _FT.rgb_to_grayscale(img), saturation_factor)
+
+
 adjust_saturation_image_pil = _FP.adjust_saturation
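
Intuition for the saturation kernel: it blends the image with its grayscale version, so a factor of 0 is fully desaturated, 1 is the identity, and values above 1 oversaturate. A quick sanity check, assuming the kernel is re-exported from the prototype functional namespace:

    import torch
    from torchvision.prototype.transforms.functional import adjust_saturation_image_tensor
    from torchvision.transforms import functional_tensor as _FT

    img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

    # Factor 1.0 is the identity; factor 0.0 collapses to grayscale,
    # broadcast from one channel back to three inside _blend.
    assert torch.equal(adjust_saturation_image_tensor(img, 1.0), img)
    gray = _FT.rgb_to_grayscale(img)
    assert torch.equal(adjust_saturation_image_tensor(img, 0.0), gray.expand_as(img))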


@@ -38,7 +71,22 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)


-adjust_contrast_image_tensor = _FT.adjust_contrast
+def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> torch.Tensor:
+    if contrast_factor < 0:
+        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(img)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    if c == 3:
+        mean = torch.mean(_FT.rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
+    else:
+        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
+
+    return _blend(img, mean, contrast_factor)
+
+
 adjust_contrast_image_pil = _FP.adjust_contrast
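
Contrast uses the same blend but against a single scalar: the mean of the grayscale image, computed in float and kept broadcastable via keepdim=True. Factor 0 therefore produces a uniform image at that mean, and factor 1 is again the identity. A sketch under the same import assumption:

    import torch
    from torchvision.prototype.transforms.functional import adjust_contrast_image_tensor

    img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

    # Factor 0.0: every pixel collapses to the grayscale mean (one unique value).
    assert adjust_contrast_image_tensor(img, 0.0).unique().numel() == 1
    # Factor 1.0: identity.
    assert torch.equal(adjust_contrast_image_tensor(img, 1.0), img)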


@@ -74,7 +122,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False

-    output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

     if needs_unsquash:
         output = output.reshape(shape)
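The sharpness change above only swaps _FT._blend for the new local _blend: sharpness remains a blend between the image and a blurred ("degenerate") copy of itself, so factor 1 keeps the input and factor 0 returns the blurred copy. A sketch, again assuming the prototype export; the border assertion relies on _blurred_degenerate_image leaving the one-pixel border untouched:

    import torch
    from torchvision.prototype.transforms.functional import adjust_sharpness_image_tensor

    img = torch.randint(0, 256, (1, 3, 16, 16), dtype=torch.uint8)

    # Factor 1.0 leaves the image untouched; 0.0 fully blurs the interior.
    assert torch.equal(adjust_sharpness_image_tensor(img, 1.0), img)
    blurred = adjust_sharpness_image_tensor(img, 0.0)
    assert torch.equal(blurred[..., 0, :], img[..., 0, :])  # border row unchanged
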
torchvision/transforms/functional_tensor.py (1 addition, 6 deletions)

@@ -816,12 +816,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
     kernel /= kernel.sum()
     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

-    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
-        img,
-        [
-            kernel.dtype,
-        ],
-    )
+    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
     result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
     result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)