
Commit ace22a5

NicolasHug authored and facebook-github-bot committed
[fbsync] [proto] Added some transformations and fixed type hints (#6245)
Summary:
* Another attempt to add transforms
* Fixed padding type hint
* Fixed fill arg for pad, rotate and affine
* Code formatting and type hints for affine transformation
* Fixed flake8
* Updated tests to save and load transforms
* Fixed code formatting issue
* Fixed jit loading issue
* Restored fill default value to None; updated code according to the review
* Added tests for rotation, affine and zoom transforms
* Put back commented code
* Random erase now bypasses boxes and masks; went back to if-return/elif-return/else-return
* Fixed acceptable and non-acceptable types for Cutmix/Mixup
* Updated conditions for _BaseMixupCutmix

Reviewed By: jdsgomes

Differential Revision: D37993418

fbshipit-source-id: 900faa217ce7cc0297eaa62671fe0518e8b4bc83
1 parent f113498 commit ace22a5

17 files changed (+497, -498 lines)
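The central type change across these files: `fill`, previously `Optional[List[float]]` (or a non-optional `Union` defaulting to `0` for `pad`), now uniformly accepts an int or float scalar or a sequence of either, and defaults to `None`. A minimal sketch of what the widened signatures accept, assuming the `features.Image` constructor takes a plain tensor as the tests below suggest; exact kernel behavior may differ:

    import torch
    from torchvision.prototype import features

    img = features.Image(torch.zeros(3, 32, 32))

    # All of these are now valid fill arguments for pad/rotate/affine/perspective:
    img.pad([2, 2], fill=None)     # default; images fall back to fill=0
    img.pad([2, 2], fill=127)      # int scalar
    img.rotate(30.0, fill=1.5)     # float scalar
    img.rotate(30.0, fill=[12.0])  # sequence of floats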

test/test_functional_tensor.py

Lines changed: 1 addition & 12 deletions
@@ -955,18 +955,7 @@ def test_adjust_gamma(device, dtype, config, channels):
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
-@pytest.mark.parametrize(
-    "pad",
-    [
-        2,
-        [
-            3,
-        ],
-        [0, 3],
-        (3, 3),
-        [4, 2, 4, 3],
-    ],
-)
+@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
 @pytest.mark.parametrize(
     "config",
     [

test/test_prototype_transforms.py

Lines changed: 8 additions & 1 deletion
@@ -3,7 +3,11 @@
 import pytest
 import torch
 from common_utils import assert_equal
-from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import (
+    make_images,
+    make_bounding_boxes,
+    make_one_hot_labels,
+)
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 
@@ -72,6 +76,9 @@ class TestSmoke:
         transforms.ConvertImageDtype(),
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
+        transforms.RandomZoomOut(),
+        transforms.RandomRotation(degrees=(-45, 45)),
+        transforms.RandomAffine(degrees=(-45, 45)),
     )
     def test_common(self, transform, input):
         transform(input)
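For context, the three transforms added to the smoke test are exercised like any other entry: constructed with the arguments shown above and called on each supported input type. A minimal sketch of the same check outside pytest, assuming the prototype transforms accept plain image tensors as the existing smoke test implies:

    import torch
    from torchvision.prototype import transforms

    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    for transform in (
        transforms.RandomZoomOut(),
        transforms.RandomRotation(degrees=(-45, 45)),
        transforms.RandomAffine(degrees=(-45, 45)),
    ):
        output = transform(image)  # should run without raising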

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def rotate_image_tensor():
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
-        [None, [128]],  # fill
+        [None, [128], [12.0]],  # fill
     ):
         if center is not None and expand:
             # Skip warning: The provided center argument is ignored if expand is True
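The new `[12.0]` entry extends the parameter grid so that float fills are covered alongside int fills. Roughly, the grid now includes calls of this shape (the call form is taken from the `_image.py` diff below; broadcasting of a length-1 fill list over channels is assumed to match stable torchvision):

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.zeros(3, 32, 32)
    out_int = F.rotate_image_tensor(img, angle=15.0, expand=False, fill=[128])
    out_float = F.rotate_image_tensor(img, angle=15.0, expand=False, fill=[12.0])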

torchvision/prototype/features/_bounding_box.py

Lines changed: 11 additions & 13 deletions
@@ -128,13 +128,20 @@ def resized_crop(
         return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
 
     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
 
         if padding_mode not in ["constant"]:
             raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
 
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
         output = _F.pad_bounding_box(self, padding, format=self.format)
 
         # Update output image size:

@@ -153,7 +160,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F

@@ -173,7 +180,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F

@@ -194,18 +201,9 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
 
         output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
         return BoundingBox.new_like(self, output, dtype=output.dtype)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
-        raise TypeError("Erase transformation does not support bounding boxes")
-
-    def mixup(self, lam: float) -> BoundingBox:
-        raise TypeError("Mixup transformation does not support bounding boxes")
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
-        raise TypeError("Cutmix transformation does not support bounding boxes")

torchvision/prototype/features/_feature.py

Lines changed: 7 additions & 13 deletions
@@ -120,7 +120,10 @@ def resized_crop(
         return self
 
     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> Any:
         return self
 
@@ -129,7 +132,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self

@@ -141,7 +144,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self

@@ -150,7 +153,7 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Any:
         return self
 
@@ -186,12 +189,3 @@ def equalize(self) -> Any:
 
     def invert(self) -> Any:
         return self
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
-        return self
-
-    def mixup(self, lam: float) -> Any:
-        return self
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
-        return self
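All of these `_Feature` methods simply `return self`: the base class defines the full transform surface as no-ops, and concrete features (`Image`, `BoundingBox`, `SegmentationMask`, ...) override only the hooks they support, so a transform can be applied to a mixed sample without per-type dispatch at the call site. A minimal sketch of that convention (class names hypothetical):

    from typing import Any

    class FeatureBase:
        def rotate(self, angle: float, **kwargs: Any) -> Any:
            return self  # unsupported by default: pass through unchanged

    class ImageLike(FeatureBase):
        def rotate(self, angle: float, **kwargs: Any) -> Any:
            # A real implementation would dispatch to a rotation kernel here.
            print(f"rotating by {angle} degrees")
            return self

    for feature in (FeatureBase(), ImageLike()):
        feature.rotate(45.0)  # only ImageLike actually reacts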

torchvision/prototype/features/_image.py

Lines changed: 23 additions & 29 deletions
@@ -164,10 +164,20 @@ def resized_crop(
         return Image.new_like(self, output)
 
     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> Image:
         from torchvision.prototype.transforms import functional as _F
 
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
+        if fill is None:
+            fill = 0
+
         # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
         if isinstance(fill, (int, float)):
             output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)

@@ -183,10 +193,12 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F
+
+        fill = _F._convert_fill_arg(fill)
 
         output = _F.rotate_image_tensor(
             self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center

@@ -200,10 +212,12 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F
+
+        fill = _F._convert_fill_arg(fill)
 
         output = _F.affine_image_tensor(
             self,

@@ -221,9 +235,11 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F
+
+        fill = _F._convert_fill_arg(fill)
 
         output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
         return Image.new_like(self, output)

@@ -293,25 +309,3 @@ def invert(self) -> Image:
 
         output = _F.invert_image_tensor(self)
         return Image.new_like(self, output)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image:
-        from torchvision.prototype.transforms import functional as _F
-
-        output = _F.erase_image_tensor(self, i, j, h, w, v)
-        return Image.new_like(self, output)
-
-    def mixup(self, lam: float) -> Image:
-        if self.ndim < 4:
-            raise ValueError("Need a batch of images")
-        output = self.clone()
-        output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
-        return Image.new_like(self, output)
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image:
-        if self.ndim < 4:
-            raise ValueError("Need a batch of images")
-        x1, y1, x2, y2 = box
-        image_rolled = self.roll(1, -4)
-        output = self.clone()
-        output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-        return Image.new_like(self, output)
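`_convert_fill_arg` is imported from the private `_geometry` module, but its body is not part of this diff. Judging from the widened public types and the `Optional[List[float]]` the low-level `*_image_tensor` kernels previously expected, it plausibly normalizes the accepted scalars and sequences; a sketch under that assumption, not the actual implementation:

    from typing import List, Optional, Sequence, Union

    def convert_fill_arg(
        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
    ) -> Optional[List[float]]:
        # Normalize every accepted fill form into the Optional[List[float]]
        # shape that the tensor kernels consume.
        if fill is None:
            return None
        if isinstance(fill, (int, float)):
            return [float(fill)]
        return [float(v) for v in fill]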

torchvision/prototype/features/_label.py

Lines changed: 1 addition & 12 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Optional, Sequence, cast, Union, Tuple
+from typing import Any, Optional, Sequence, cast, Union
 
 import torch
 from torchvision.prototype.utils._internal import apply_recursively

@@ -77,14 +77,3 @@ def new_like(
         return super().new_like(
             other, data, categories=categories if categories is not None else other.categories, **kwargs
         )
-
-    def mixup(self, lam: float) -> OneHotLabel:
-        if self.ndim < 2:
-            raise ValueError("Need a batch of one hot labels")
-        output = self.clone()
-        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
-        return OneHotLabel.new_like(self, output)
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel:
-        box  # unused
-        return self.mixup(lam_adjusted)
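The removed `mixup` body encodes the standard convex combination: each one-hot row is blended with its neighbour in the batch (the batch rolled by one) using weight `lam`. For reference, a standalone, out-of-place version of the same arithmetic:

    import torch

    def mixup_one_hot(labels: torch.Tensor, lam: float) -> torch.Tensor:
        if labels.ndim < 2:
            raise ValueError("Need a batch of one hot labels")
        # roll(1, -2) pairs sample i with sample i-1 along the batch dimension
        return labels.roll(1, -2).mul(1 - lam).add(labels.mul(lam))

    labels = torch.eye(4)  # batch of 4 one-hot labels
    mixed = mixup_one_hot(labels, lam=0.7)
    assert torch.allclose(mixed.sum(-1), torch.ones(4))  # rows still sum to 1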

torchvision/prototype/features/_segmentation_mask.py

Lines changed: 12 additions & 15 deletions
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import Tuple, List, Optional, Union, Sequence
+from typing import List, Optional, Union, Sequence
 
-import torch
 from torchvision.transforms import InterpolationMode
 
 from ._feature import _Feature

@@ -61,10 +60,17 @@ def resized_crop(
         return SegmentationMask.new_like(self, output)
 
     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
    ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
 
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
         output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
         return SegmentationMask.new_like(self, output)
 
@@ -73,7 +79,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F

@@ -88,7 +94,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F

@@ -107,18 +113,9 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
 
         output = _F.perspective_segmentation_mask(self, perspective_coeffs)
         return SegmentationMask.new_like(self, output)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask:
-        raise TypeError("Erase transformation does not support segmentation masks")
-
-    def mixup(self, lam: float) -> SegmentationMask:
-        raise TypeError("Mixup transformation does not support segmentation masks")
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask:
-        raise TypeError("Cutmix transformation does not support segmentation masks")

torchvision/prototype/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
     RandomVerticalFlip,
     Pad,
     RandomZoomOut,
+    RandomRotation,
+    RandomAffine,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
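With these two names added to the re-export list, the new transforms become importable from the prototype namespace alongside the existing ones:

    from torchvision.prototype.transforms import Pad, RandomAffine, RandomRotation, RandomZoomOut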
