
Commit d8d980c

datumbox authored and facebook-github-bot committed
[fbsync] support grayscale / RGB alpha conversions (#5567)
Summary:

* support grayscale / RGB alpha conversions
* use _max_value from stable
* remove extra copy for PIL conversion
* simplify test image generation for color spaces with alpha channel
* use common _max_value in tests
* replace dynamically created dicts with if/else
* make color space conversion more explicit
* make even more explicit
* simplify alpha image generation
* fix if / elif
* add error for unknown conversions
* rename RGBA to RGB_ALPHA
* cleanup
* rename GRAYSCALE to GRAY

Reviewed By: vmoens

Differential Revision: D34878979

fbshipit-source-id: 4ec3af94e73152d11017c4e9c57ded44b2076764
1 parent 01a3fe1 commit d8d980c

File tree

5 files changed, +168 -64 lines changed


test/test_prototype_transforms.py

Lines changed: 37 additions & 1 deletion
@@ -102,7 +102,14 @@ def test_mixup_cutmix(self, transform, input):
             (
                 transform,
                 itertools.chain.from_iterable(
-                    fn(dtypes=[torch.uint8], extra_dims=[(4,)])
+                    fn(
+                        color_spaces=[
+                            features.ColorSpace.GRAY,
+                            features.ColorSpace.RGB,
+                        ],
+                        dtypes=[torch.uint8],
+                        extra_dims=[(4,)],
+                    )
                     for fn in [
                         make_images,
                         make_vanilla_tensor_images,
@@ -152,3 +159,32 @@ def test_normalize(self, transform, input):
     )
     def test_random_resized_crop(self, transform, input):
         transform(input)
+
+    @parametrize(
+        [
+            (
+                transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
+                itertools.chain.from_iterable(
+                    [
+                        fn(color_spaces=[old_color_space])
+                        for fn in (
+                            make_images,
+                            make_vanilla_tensor_images,
+                            make_pil_images,
+                        )
+                    ]
+                ),
+            )
+            for old_color_space, new_color_space in itertools.product(
+                [
+                    features.ColorSpace.GRAY,
+                    features.ColorSpace.GRAY_ALPHA,
+                    features.ColorSpace.RGB,
+                    features.ColorSpace.RGB_ALPHA,
+                ],
+                repeat=2,
+            )
+        ]
+    )
+    def test_convert_image_color_space(self, transform, input):
+        transform(input)
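For orientation, the itertools.product(..., repeat=2) expression above enumerates every ordered (old, new) pair of the four color spaces, identity pairs included, so the smoke test covers all 16 conversion directions. A minimal sketch, with plain strings standing in for the features.ColorSpace members:

import itertools

# Plain strings stand in for features.ColorSpace members in this sketch.
color_spaces = ["GRAY", "GRAY_ALPHA", "RGB", "RGB_ALPHA"]

# repeat=2 yields every ordered (old, new) pair, identity pairs included.
pairs = list(itertools.product(color_spaces, repeat=2))
assert len(pairs) == 16
print(pairs[:3])  # [('GRAY', 'GRAY'), ('GRAY', 'GRAY_ALPHA'), ('GRAY', 'RGB')]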

test/test_prototype_transforms_functional.py

Lines changed: 30 additions & 22 deletions
@@ -7,33 +7,44 @@
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
+from torchvision.transforms.functional_tensor import _max_value as get_max_value

 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


-def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
+def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
     size = size or torch.randint(16, 33, (2,)).tolist()

-    num_channels = {
-        features.ColorSpace.GRAYSCALE: 1,
-        features.ColorSpace.RGB: 3,
-    }[color_space]
+    try:
+        num_channels = {
+            features.ColorSpace.GRAY: 1,
+            features.ColorSpace.GRAY_ALPHA: 2,
+            features.ColorSpace.RGB: 3,
+            features.ColorSpace.RGB_ALPHA: 4,
+        }[color_space]
+    except KeyError as error:
+        raise pytest.UsageError() from error

     shape = (*extra_dims, num_channels, *size)
-    if dtype.is_floating_point:
-        data = torch.rand(shape, dtype=dtype)
-    else:
-        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
+    max_value = get_max_value(dtype)
+    data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
+    if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
+        data[..., -1, :, :] = max_value
     return features.Image(data, color_space=color_space)


-make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
+make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
 make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)


 def make_images(
     sizes=((16, 16), (7, 33), (31, 9)),
-    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
+    color_spaces=(
+        features.ColorSpace.GRAY,
+        features.ColorSpace.GRAY_ALPHA,
+        features.ColorSpace.RGB,
+        features.ColorSpace.RGB_ALPHA,
+    ),
     dtypes=(torch.float32, torch.uint8),
     extra_dims=((4,), (2, 3)),
 ):
@@ -48,15 +59,12 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     low, high = torch.broadcast_tensors(
         *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
     )
-    try:
-        return torch.stack(
-            [
-                torch.randint(low_scalar, high_scalar, (), **kwargs)
-                for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
-            ]
-        ).reshape(low.shape)
-    except RuntimeError as error:
-        raise error
+    return torch.stack(
+        [
+            torch.randint(low_scalar, high_scalar, (), **kwargs)
+            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
+        ]
+    ).reshape(low.shape)


 def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
@@ -83,8 +91,8 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
         h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
         parts = (cx, cy, w, h)
-    else:  # format == features.BoundingBoxFormat._SENTINEL:
-        raise ValueError()
+    else:
+        raise pytest.UsageError()

     return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
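The key new behavior in make_image is the constant_alpha handling: channels live at dimension -3, so indexing data[..., -1, :, :] selects the alpha plane and pins it to the dtype's maximum value. A standalone sketch of that logic, with the uint8 max value hard-coded rather than taken from _max_value:

import torch

# Shape for one RGB_ALPHA test image with extra_dims=(2,): channels at dim -3.
shape = (2, 4, 16, 16)
max_value = 255  # what _max_value(torch.uint8) evaluates to

data = torch.randint(0, max_value + 1, shape, dtype=torch.uint8)
# Index -1 along the channel dimension is the alpha plane; pin it to opaque.
data[..., -1, :, :] = max_value

assert torch.all(data[..., -1, :, :] == max_value)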

torchvision/prototype/features/_image.py

Lines changed: 22 additions & 3 deletions
@@ -15,8 +15,23 @@

 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
-    GRAYSCALE = StrEnum.auto()
+    GRAY = StrEnum.auto()
+    GRAY_ALPHA = StrEnum.auto()
     RGB = StrEnum.auto()
+    RGB_ALPHA = StrEnum.auto()
+
+    @classmethod
+    def from_pil_mode(cls, mode: str) -> ColorSpace:
+        if mode == "L":
+            return cls.GRAY
+        elif mode == "LA":
+            return cls.GRAY_ALPHA
+        elif mode == "RGB":
+            return cls.RGB
+        elif mode == "RGBA":
+            return cls.RGB_ALPHA
+        else:
+            return cls.OTHER


 class Image(_Feature):
@@ -71,13 +86,17 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         if data.ndim < 2:
             return ColorSpace.OTHER
         elif data.ndim == 2:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY

         num_channels = data.shape[-3]
         if num_channels == 1:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY
+        elif num_channels == 2:
+            return ColorSpace.GRAY_ALPHA
         elif num_channels == 3:
             return ColorSpace.RGB
+        elif num_channels == 4:
+            return ColorSpace.RGB_ALPHA
         else:
             return ColorSpace.OTHER
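A usage sketch for the renamed enum and the new from_pil_mode helper (prototype API as of this commit; that guess_color_space is reachable as a static method on features.Image is an assumption based on the hunk context):

import torch
from torchvision.prototype import features

# PIL modes map onto the renamed members; unknown modes fall back to OTHER.
assert features.ColorSpace.from_pil_mode("LA") is features.ColorSpace.GRAY_ALPHA
assert features.ColorSpace.from_pil_mode("CMYK") is features.ColorSpace.OTHER

# 2- and 4-channel tensors are now recognized as the alpha variants.
assert features.Image.guess_color_space(torch.rand(2, 32, 32)) is features.ColorSpace.GRAY_ALPHA
assert features.Image.guess_color_space(torch.rand(4, 32, 32)) is features.ColorSpace.RGB_ALPHA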

torchvision/prototype/transforms/_meta.py

Lines changed: 3 additions & 10 deletions
@@ -48,11 +48,11 @@ def __init__(
         super().__init__()

         if isinstance(color_space, str):
-            color_space = features.ColorSpace[color_space]
+            color_space = features.ColorSpace.from_str(color_space)
         self.color_space = color_space

         if isinstance(old_color_space, str):
-            old_color_space = features.ColorSpace[old_color_space]
+            old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
@@ -72,13 +72,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space
             )
         elif isinstance(input, PIL.Image.Image):
-            old_color_space = {
-                "L": features.ColorSpace.GRAYSCALE,
-                "RGB": features.ColorSpace.RGB,
-            }.get(input.mode, features.ColorSpace.OTHER)
-
-            return F.convert_image_color_space_pil(
-                input, old_color_space=old_color_space, new_color_space=self.color_space
-            )
+            return F.convert_image_color_space_pil(input, color_space=self.color_space)
         else:
             return input
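With the PIL branch simplified, the transform no longer guesses the source mode itself: a PIL image carries its mode, so only the target color_space is passed through. A hedged usage sketch (that from_str accepts the member names as strings, and the exact dispatch behavior, are inferred from the diff rather than verified):

import PIL.Image
import torch
from torchvision.prototype import features, transforms

transform = transforms.ConvertImageColorSpace(color_space="RGB_ALPHA", old_color_space="RGB")

# Tensor inputs go through the tensor kernel and gain an alpha channel...
image = features.Image(torch.rand(3, 16, 16), color_space=features.ColorSpace.RGB)
print(transform(image).shape)  # expected: (4, 16, 16)

# ...while PIL inputs are converted via PIL's own mode machinery.
print(transform(PIL.Image.new("RGB", (16, 16))).mode)  # expected: "RGBA"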

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 76 additions & 28 deletions
@@ -1,9 +1,10 @@
+from typing import Tuple, Optional
+
 import PIL.Image
 import torch
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP

-
 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions

@@ -57,41 +58,88 @@ def convert_bounding_box_format(
     return bounding_box


-def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)
+def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return image[..., :-1, :, :], image[..., -1:, :, :]


-def convert_image_color_space_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image.clone()
+def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
+    image, alpha = _split_alpha(image)
+    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+        raise RuntimeError(
+            "Stripping the alpha channel if it contains values other than the max value is not supported."
+        )
+    return image

-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_tensor(image)

-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FT.rgb_to_grayscale(image)
+def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if alpha is None:
+        shape = list(image.shape)
+        shape[-3] = 1
+        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+    return torch.cat((image, alpha), dim=-3)

-    return image

-
-def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
-    return grayscale.convert("RGB")
+def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)


-def convert_image_color_space_pil(
-    image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> PIL.Image.Image:
-    if new_color_space == old_color_space:
-        return image.copy()
+_rgb_to_gray = _FT.rgb_to_grayscale

-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_pil(image)

-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FP.to_grayscale(image)
+def convert_image_color_space_tensor(
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image.clone()

-    return image
+    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
+        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
+
+    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(_gray_to_rgb(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _strip_alpha(image)
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(_strip_alpha(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_gray_to_rgb(image), alpha)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(image)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(_rgb_to_gray(image))
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(_strip_alpha(image))
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_rgb_to_gray(image), alpha)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
+        return _strip_alpha(image)
+    else:
+        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
+
+
+_COLOR_SPACE_TO_PIL_MODE = {
+    ColorSpace.GRAY: "L",
+    ColorSpace.GRAY_ALPHA: "LA",
+    ColorSpace.RGB: "RGB",
+    ColorSpace.RGB_ALPHA: "RGBA",
+}
+
+
+def convert_image_color_space_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
+    old_mode = image.mode
+    try:
+        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
+    except KeyError:
+        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
+
+    return image.convert(new_mode)
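A hedged round-trip sketch of the tensor kernel above (assuming convert_image_color_space_tensor is re-exported from torchvision.prototype.transforms.functional, as the F. call sites in _meta.py suggest; for float dtypes _FT._max_value is 1, so the generated alpha plane is all ones):

import torch
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms.functional import convert_image_color_space_tensor

rgb = torch.rand(3, 8, 8)

# RGB -> RGB_ALPHA appends a fully opaque alpha plane (1.0 for float images).
rgba = convert_image_color_space_tensor(
    rgb, old_color_space=ColorSpace.RGB, new_color_space=ColorSpace.RGB_ALPHA
)
assert rgba.shape == (4, 8, 8) and torch.all(rgba[-1] == 1.0)

# RGB_ALPHA -> RGB strips the alpha plane, but only when it is constant at the
# max value; a non-opaque alpha raises RuntimeError instead of silently
# discarding information.
back = convert_image_color_space_tensor(
    rgba, old_color_space=ColorSpace.RGB_ALPHA, new_color_space=ColorSpace.RGB
)
assert back.shape == (3, 8, 8)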
