|
6 | 6 | import torch
|
7 | 7 | from torchvision.prototype import features
|
8 | 8 | from torchvision.prototype.transforms import Transform, functional as F
|
| 9 | +from torchvision.transforms import functional as _F |
9 | 10 |
|
10 |
| -from ._utils import is_simple_tensor |
| 11 | +from ._utils import is_simple_tensor, get_image_dimensions, query_image |
11 | 12 |
|
12 | 13 | T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
|
13 | 14 |
|
@@ -120,5 +121,70 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
|
120 | 121 |
|
121 | 122 | for transform in params["image_transforms"]:
|
122 | 123 | input = transform(input)
|
| 124 | + return input |
| 125 | + |
| 126 | + |
class _RandomChannelShuffle(Transform):
    """Reorder the channels of an image with a random permutation."""

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Sample one permutation sized by the image found in ``sample``.

        The first value returned by ``get_image_dimensions`` is used as the
        number of channels to permute.
        """
        num_channels, _, _ = get_image_dimensions(query_image(sample))
        return {"permutation": torch.randperm(num_channels)}

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        """Apply ``params["permutation"]`` to the channels of ``input``.

        Inputs that are not a ``features.Image``, a PIL image or a simple
        tensor are returned unchanged. The result has the same kind as the
        input: PIL in, PIL out; ``features.Image`` in, ``features.Image`` out
        (tagged with ``ColorSpace.OTHER``); plain tensor in, plain tensor out.
        """
        is_feature_image = isinstance(input, features.Image)
        is_pil_image = isinstance(input, PIL.Image.Image)

        # Anything that is not image-like passes through untouched.
        if not (is_feature_image or is_pil_image or is_simple_tensor(input)):
            return input

        # PIL images are converted to a tensor so they can be indexed.
        tensor = _F.pil_to_tensor(input) if is_pil_image else input

        # Index the third-from-last dimension with the sampled permutation.
        shuffled = tensor[..., params["permutation"], :, :]

        if is_feature_image:
            return features.Image.new_like(input, shuffled, color_space=features.ColorSpace.OTHER)
        if is_pil_image:
            return _F.to_pil_image(shuffled)
        return shuffled
123 | 149 |
|
| 150 | + |
class RandomPhotometricDistort(Transform):
    """Randomly apply a chain of photometric distortions to the input.

    Brightness, contrast, saturation, hue and channel shuffling are each
    enabled independently with probability ``p``. Contrast is applied either
    before or after the saturation/hue adjustments, decided by a fair coin
    flip.

    Args:
        contrast: Range forwarded to ``ColorJitter(contrast=...)``.
        saturation: Range forwarded to ``ColorJitter(saturation=...)``.
        hue: Range forwarded to ``ColorJitter(hue=...)``.
        brightness: Range forwarded to ``ColorJitter(brightness=...)``.
        p: Probability with which each individual distortion is enabled.
    """

    def __init__(
        self,
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        brightness: Tuple[float, float] = (0.875, 1.125),
        p: float = 0.5,
    ):
        super().__init__()
        # One single-purpose jitter per property so that each can be toggled
        # on its own in ``_transform``.
        self._brightness = ColorJitter(brightness=brightness)
        self._contrast = ColorJitter(contrast=contrast)
        self._hue = ColorJitter(hue=hue)
        self._saturation = ColorJitter(saturation=saturation)
        self._channel_shuffle = _RandomChannelShuffle()
        self.p = p

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Draw the on/off flag of every distortion plus the contrast position.

        Returns:
            A dict mapping ``"brightness"``, ``"contrast1"``, ``"saturation"``,
            ``"hue"``, ``"contrast2"`` and ``"channel_shuffle"`` to boolean
            draws that are true with probability ``p``, and
            ``"contrast_before"`` to a fair coin flip.
        """
        return dict(
            zip(
                ["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
                torch.rand(6) < self.p,
            ),
            contrast_before=torch.rand(()) < 0.5,
        )

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        """Apply the distortions selected in ``params`` to ``input`` in order."""
        if params["brightness"]:
            input = self._brightness(input)
        # Only one of the two contrast slots is ever consulted, depending on
        # ``contrast_before``.
        if params["contrast1"] and params["contrast_before"]:
            input = self._contrast(input)
        if params["saturation"]:
            input = self._saturation(input)
        # Fix: this branch previously repeated the saturation step, so the
        # sampled "hue" flag and ``self._hue`` were never used.
        if params["hue"]:
            input = self._hue(input)
        if params["contrast2"] and not params["contrast_before"]:
            input = self._contrast(input)
        if params["channel_shuffle"]:
            input = self._channel_shuffle(input)
        return input
|
0 commit comments