
Commit 0fdc835

jdsgomes and Federico Pozzi authored and committed
[fbsync] refactor: port RandomVerticalFlip to prototype API (#5524) (#5633)
Summary: (Note: this ignores all push blocking failures!)

Reviewed By: datumbox

Differential Revision: D35216779

fbshipit-source-id: 493f6eac276a1ec2f5ed869493db4eef548b3369

Co-authored-by: Federico Pozzi <[email protected]>
1 parent 9f6006c · commit 0fdc835

File tree

5 files changed: +106 -0 lines changed
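For orientation before the per-file diff: a minimal usage sketch of the ported transform. This example is illustrative and not part of the commit; it assumes the prototype namespaces exactly as imported in the files below.

import torch

from torchvision.prototype import features, transforms

# p=1.0 flips deterministically for the example; the default is p=0.5.
transform = transforms.RandomVerticalFlip(p=1.0)

image = features.Image(torch.tensor([[[1.0, 1.0], [0.0, 0.0]]]))
flipped = transform(image)  # rows reversed top-to-bottom: [[0, 0], [1, 1]]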

test/test_prototype_transforms.py

Lines changed: 53 additions & 0 deletions
@@ -243,3 +243,56 @@ def test_features_bounding_box(self, p):
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
+
+
+@pytest.mark.parametrize("p", [0.0, 1.0])
+class TestRandomVerticalFlip:
+    def input_expected_image_tensor(self, p, dtype=torch.float32):
+        input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
+        expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)
+
+        return input, expected if p == 1 else input
+
+    def test_simple_tensor(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(input)
+
+        assert_equal(expected, actual)
+
+    def test_pil_image(self, p):
+        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(to_pil_image(input))
+
+        assert_equal(expected, pil_to_tensor(actual))
+
+    def test_features_image(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(features.Image(input))
+
+        assert_equal(features.Image(expected), actual)
+
+    def test_features_segmentation_mask(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(features.SegmentationMask(input))
+
+        assert_equal(features.SegmentationMask(expected), actual)
+
+    def test_features_bounding_box(self, p):
+        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(input)
+
+        expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
+        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        assert_equal(expected, actual)
+        assert actual.format == expected.format
+        assert actual.image_size == expected.image_size
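The bounding-box expectation in test_features_bounding_box follows from the vertical-flip mapping y -> H - y, with ymin and ymax swapping roles. A plain-torch check of the arithmetic (illustrative only, not part of the commit):

import torch

# XYXY box [0, 0, 5, 5] on an image of height H = 10:
# new ymin = H - old ymax, new ymax = H - old ymin.
H = 10
xmin, ymin, xmax, ymax = 0, 0, 5, 5
flipped = torch.tensor([xmin, H - ymax, xmax, H - ymin])
assert flipped.tolist() == [0, 5, 5, 10]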

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     TenCrop,
     BatchMultiCrop,
     RandomHorizontalFlip,
+    RandomVerticalFlip,
     Pad,
     RandomZoomOut,
 )

torchvision/prototype/transforms/_geometry.py

Lines changed: 30 additions & 0 deletions
@@ -45,6 +45,36 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         return input


+class RandomVerticalFlip(Transform):
+    def __init__(self, p: float = 0.5) -> None:
+        super().__init__()
+        self.p = p
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) > self.p:
+            return sample
+
+        return super().forward(sample)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.vertical_flip_image_tensor(input)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.SegmentationMask):
+            output = F.vertical_flip_segmentation_mask(input)
+            return features.SegmentationMask.new_like(input, output)
+        elif isinstance(input, features.BoundingBox):
+            output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
+            return features.BoundingBox.new_like(input, output)
+        elif isinstance(input, PIL.Image.Image):
+            return F.vertical_flip_image_pil(input)
+        elif is_simple_tensor(input):
+            return F.vertical_flip_image_tensor(input)
+        else:
+            return input
+
+
 class Resize(Transform):
     def __init__(
         self,
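Two things are worth noting in the class above: forward draws a single random number per call, so the flip is applied with probability p (p=0.0 never flips and p=1.0 always does, matching the two parametrized test cases), and _transform returns unsupported inputs unchanged. A small sketch of that pass-through behavior, assuming the prototype Transform base applies _transform to each element of a sample (illustrative, not part of the commit):

import torch

from torchvision.prototype import features, transforms

t = transforms.RandomVerticalFlip(p=1.0)

mask = features.SegmentationMask(torch.zeros(1, 4, 4))
flipped_mask = t(mask)  # routed to vertical_flip_segmentation_mask

label = 5  # plain ints match no branch, so the final `else` returns them as-is
assert t(label) == 5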

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@
     perspective_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
+    vertical_flip_bounding_box,
+    vertical_flip_segmentation_mask,
     five_crop_image_tensor,
     five_crop_image_pil,
     ten_crop_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 20 additions & 0 deletions
@@ -81,6 +81,26 @@ def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size:
 vertical_flip_image_pil = _FP.vflip


+def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
+    return vertical_flip_image_tensor(segmentation_mask)
+
+
+def vertical_flip_bounding_box(
+    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    shape = bounding_box.shape
+
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
+
+    return convert_bounding_box_format(
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(shape)
+
+
 def _affine_parse_args(
     angle: float,
     translate: List[float],
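vertical_flip_segmentation_mask simply delegates to vertical_flip_image_tensor, since flipping a mask vertically is just reversing its rows. A plain-torch sketch of the equivalent row reversal (illustrative, not the library implementation):

import torch

# Vertically flipping a (C, H, W) tensor reverses the height (row) dimension.
mask = torch.tensor([[[1, 1], [0, 0]]])  # shape (1, 2, 2)
flipped = torch.flip(mask, dims=[-2])
assert flipped.tolist() == [[[0, 0], [1, 1]]]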
