@@ -1,9 +1,10 @@
+from typing import Tuple, Optional
+
 import PIL.Image
 import torch
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 
-
 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions
 
@@ -57,41 +58,88 @@ def convert_bounding_box_format(
     return bounding_box
 
 
-def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)
+def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return image[..., :-1, :, :], image[..., -1:, :, :]
 
 
-def convert_image_color_space_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image.clone()
+def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
+    image, alpha = _split_alpha(image)
+    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+        raise RuntimeError(
+            "Stripping the alpha channel if it contains values other than the max value is not supported."
+        )
+    return image
 
-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_tensor(image)
 
-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FT.rgb_to_grayscale(image)
+def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if alpha is None:
+        shape = list(image.shape)
+        shape[-3] = 1
+        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+    return torch.cat((image, alpha), dim=-3)
 
-    return image
 
-
-def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
-    return grayscale.convert("RGB")
+def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)
 
 
-def convert_image_color_space_pil(
-    image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> PIL.Image.Image:
-    if new_color_space == old_color_space:
-        return image.copy()
+_rgb_to_gray = _FT.rgb_to_grayscale
 
-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_pil(image)
 
-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FP.to_grayscale(image)
+def convert_image_color_space_tensor(
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image.clone()
 
-    return image
+    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
+        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
+
+    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(_gray_to_rgb(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _strip_alpha(image)
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(_strip_alpha(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_gray_to_rgb(image), alpha)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(image)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(_rgb_to_gray(image))
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(_strip_alpha(image))
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_rgb_to_gray(image), alpha)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
+        return _strip_alpha(image)
+    else:
+        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
+
+
+_COLOR_SPACE_TO_PIL_MODE = {
+    ColorSpace.GRAY: "L",
+    ColorSpace.GRAY_ALPHA: "LA",
+    ColorSpace.RGB: "RGB",
+    ColorSpace.RGB_ALPHA: "RGBA",
+}
+
+
+def convert_image_color_space_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
+    old_mode = image.mode
+    try:
+        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
+    except KeyError:
+        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
+
+    return image.convert(new_mode)
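
A brief usage sketch of the reworked kernels may help reviewers. The import path, tensor shape, and example values below are illustrative assumptions and not part of this diff; only the function signatures and ColorSpace members come from the change itself.

# Usage sketch; the functional import path is assumed, not confirmed by this diff.
import PIL.Image
import torch
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import functional as F  # assumed re-export location

# A 3-channel uint8 RGB image tensor in CHW layout; shape and values are illustrative.
rgb = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

# RGB -> RGB_ALPHA appends a fully opaque alpha channel (255 for uint8).
rgba = F.convert_image_color_space_tensor(rgb, ColorSpace.RGB, ColorSpace.RGB_ALPHA)
assert rgba.shape == (4, 4, 4)

# RGB_ALPHA -> GRAY strips the (all-opaque) alpha channel, then converts to grayscale.
gray = F.convert_image_color_space_tensor(rgba, ColorSpace.RGB_ALPHA, ColorSpace.GRAY)
assert gray.shape == (1, 4, 4)

# The PIL kernel now takes only the target color space and maps it to a PIL mode.
pil_gray = F.convert_image_color_space_pil(PIL.Image.new("RGB", (4, 4)), ColorSpace.GRAY)
assert pil_gray.mode == "L"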