Skip to content

Commit f6086dd

Browse files
prabhat00155 (author) and datumbox (committer)
authored and committed
[fbsync] Prototype transforms cleanup (#5504)
Summary: * fix grayscale to RGB for batches * make unsupported types in auto augment a parameter * make auto augment kwargs explicit * add missing error message * add support for specifying probabilites on RandomChoice * remove TODO for deprecating p on random transforms * streamline sample type checking * address comments * split image_size into height and width in auto augment Reviewed By: datumbox Differential Revision: D34579511 fbshipit-source-id: 757663a5a77f229cd1592b4c23dc17c7e8fe4807 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 7998be9 commit f6086dd

File tree

7 files changed

+184
-104
lines changed

7 files changed

+184
-104
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchvision.prototype import features
88
from torchvision.prototype.transforms import Transform, functional as F
99

10-
from ._utils import query_image, get_image_dimensions
10+
from ._utils import query_image, get_image_dimensions, has_all, has_any
1111

1212

1313
class RandomErasing(Transform):
@@ -33,7 +33,6 @@ def __init__(
3333
raise ValueError("Scale should be between 0 and 1")
3434
if p < 0 or p > 1:
3535
raise ValueError("Random erasing probability should be between 0 and 1")
36-
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
3736
self.p = p
3837
self.scale = scale
3938
self.ratio = ratio
@@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
8887
return dict(zip("ijhwv", (i, j, h, w, v)))
8988

9089
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
91-
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
92-
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
93-
elif isinstance(input, features.Image):
90+
if isinstance(input, features.Image):
9491
output = F.erase_image_tensor(input, **params)
9592
return features.Image.new_like(input, output)
9693
elif isinstance(input, torch.Tensor):
@@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
9996
return input
10097

10198
def forward(self, *inputs: Any) -> Any:
102-
if torch.rand(1) >= self.p:
103-
return inputs if len(inputs) > 1 else inputs[0]
99+
sample = inputs if len(inputs) > 1 else inputs[0]
100+
if has_any(sample, features.BoundingBox, features.SegmentationMask):
101+
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
102+
elif torch.rand(1) >= self.p:
103+
return sample
104104

105-
return super().forward(*inputs)
105+
return super().forward(sample)
106106

107107

108108
class RandomMixup(Transform):
@@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
115115
return dict(lam=float(self._dist.sample(())))
116116

117117
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
118-
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
119-
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
120-
elif isinstance(input, features.Image):
118+
if isinstance(input, features.Image):
121119
output = F.mixup_image_tensor(input, **params)
122120
return features.Image.new_like(input, output)
123121
elif isinstance(input, features.OneHotLabel):
@@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
126124
else:
127125
return input
128126

127+
def forward(self, *inputs: Any) -> Any:
128+
sample = inputs if len(inputs) > 1 else inputs[0]
129+
if has_any(sample, features.BoundingBox, features.SegmentationMask):
130+
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
131+
elif not has_all(sample, features.Image, features.OneHotLabel):
132+
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
133+
return super().forward(sample)
134+
129135

130136
class RandomCutmix(Transform):
131137
def __init__(self, *, alpha: float) -> None:
@@ -157,13 +163,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
157163
return dict(box=box, lam_adjusted=lam_adjusted)
158164

159165
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
160-
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
161-
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
162-
elif isinstance(input, features.Image):
166+
if isinstance(input, features.Image):
163167
output = F.cutmix_image_tensor(input, box=params["box"])
164168
return features.Image.new_like(input, output)
165169
elif isinstance(input, features.OneHotLabel):
166170
output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
167171
return features.OneHotLabel.new_like(input, output)
168172
else:
169173
return input
174+
175+
def forward(self, *inputs: Any) -> Any:
176+
sample = inputs if len(inputs) > 1 else inputs[0]
177+
if has_any(sample, features.BoundingBox, features.SegmentationMask):
178+
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
179+
elif not has_all(sample, features.Image, features.OneHotLabel):
180+
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
181+
return super().forward(sample)

0 commit comments

Comments
 (0)