Skip to content

Commit 12d110c

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] port RandomPhotoMetricDistort to prototype transforms (#5663)
Reviewed By: NicolasHug Differential Revision: D35393168 fbshipit-source-id: 2a81cb6eaf15a3082826940f1aae14cd862bbc35
1 parent e17a323 commit 12d110c

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._augment import RandomErasing, RandomMixup, RandomCutmix
66
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
7-
from ._color import ColorJitter
7+
from ._color import ColorJitter, RandomPhotometricDistort
88
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
99
from ._geometry import (
1010
Resize,

torchvision/prototype/transforms/_color.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import torch
77
from torchvision.prototype import features
88
from torchvision.prototype.transforms import Transform, functional as F
9+
from torchvision.transforms import functional as _F
910

10-
from ._utils import is_simple_tensor
11+
from ._utils import is_simple_tensor, get_image_dimensions, query_image
1112

1213
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
1314

@@ -120,5 +121,70 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
120121

121122
for transform in params["image_transforms"]:
122123
input = transform(input)
124+
return input
125+
126+
127+
class _RandomChannelShuffle(Transform):
128+
def _get_params(self, sample: Any) -> Dict[str, Any]:
129+
image = query_image(sample)
130+
num_channels, _, _ = get_image_dimensions(image)
131+
return dict(permutation=torch.randperm(num_channels))
132+
133+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
134+
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
135+
return input
136+
137+
image = input
138+
if isinstance(input, PIL.Image.Image):
139+
image = _F.pil_to_tensor(image)
140+
141+
output = image[..., params["permutation"], :, :]
142+
143+
if isinstance(input, features.Image):
144+
output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
145+
elif isinstance(input, PIL.Image.Image):
146+
output = _F.to_pil_image(output)
147+
148+
return output
123149

150+
151+
class RandomPhotometricDistort(Transform):
152+
def __init__(
153+
self,
154+
contrast: Tuple[float, float] = (0.5, 1.5),
155+
saturation: Tuple[float, float] = (0.5, 1.5),
156+
hue: Tuple[float, float] = (-0.05, 0.05),
157+
brightness: Tuple[float, float] = (0.875, 1.125),
158+
p: float = 0.5,
159+
):
160+
super().__init__()
161+
self._brightness = ColorJitter(brightness=brightness)
162+
self._contrast = ColorJitter(contrast=contrast)
163+
self._hue = ColorJitter(hue=hue)
164+
self._saturation = ColorJitter(saturation=saturation)
165+
self._channel_shuffle = _RandomChannelShuffle()
166+
self.p = p
167+
168+
def _get_params(self, sample: Any) -> Dict[str, Any]:
169+
return dict(
170+
zip(
171+
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
172+
torch.rand(6) < self.p,
173+
),
174+
contrast_before=torch.rand(()) < 0.5,
175+
)
176+
177+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
178+
if params["brightness"]:
179+
input = self._brightness(input)
180+
if params["contrast1"] and params["contrast_before"]:
181+
input = self._contrast(input)
182+
if params["saturation"]:
183+
input = self._saturation(input)
184+
if params["saturation"]:
185+
input = self._saturation(input)
186+
if params["contrast2"] and not params["contrast_before"]:
187+
input = self._contrast(input)
188+
if params["channel_shuffle"]:
189+
input = self._channel_shuffle(input)
124190
return input

0 commit comments

Comments
 (0)