Skip to content

Commit 7be2f55

Browse files
federicopozzi33Federico Pozzipmeier
authored
port RandomHorizontalFlip to prototype API (#5563)
* refactor: port RandomHorizontalFlip to prototype API (#5523) * refactor: merge HorizontalFlip and RandomHorizontalFlip Add unit tests for RandomHorizontalFlip * test: RandomHorizontalFlip with p=0 * refactor: remove type annotations from tests * refactor: improve tests * Update test/test_prototype_transforms.py Co-authored-by: Federico Pozzi <[email protected]> Co-authored-by: Philip Meier <[email protected]>
1 parent 6013230 commit 7be2f55

File tree

5 files changed

+77
-4
lines changed

5 files changed

+77
-4
lines changed

test/test_prototype_transforms.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
import torch
5+
from common_utils import assert_equal
56
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
67
from torchvision.prototype import transforms, features
7-
from torchvision.transforms.functional import to_pil_image
8+
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
89

910

1011
def make_vanilla_tensor_images(*args, **kwargs):
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms):
6667
class TestSmoke:
6768
@parametrize_from_transforms(
6869
transforms.RandomErasing(p=1.0),
69-
transforms.HorizontalFlip(),
7070
transforms.Resize([16, 16]),
7171
transforms.CenterCrop([16, 16]),
7272
transforms.ConvertImageDtype(),
73+
transforms.RandomHorizontalFlip(),
7374
)
7475
def test_common(self, transform, input):
7576
transform(input)
@@ -188,3 +189,56 @@ def test_random_resized_crop(self, transform, input):
188189
)
189190
def test_convert_image_color_space(self, transform, input):
190191
transform(input)
192+
193+
194+
@pytest.mark.parametrize("p", [0.0, 1.0])
195+
class TestRandomHorizontalFlip:
196+
def input_expected_image_tensor(self, p, dtype=torch.float32):
197+
input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
198+
expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
199+
200+
return input, expected if p == 1 else input
201+
202+
def test_simple_tensor(self, p):
203+
input, expected = self.input_expected_image_tensor(p)
204+
transform = transforms.RandomHorizontalFlip(p=p)
205+
206+
actual = transform(input)
207+
208+
assert_equal(expected, actual)
209+
210+
def test_pil_image(self, p):
211+
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
212+
transform = transforms.RandomHorizontalFlip(p=p)
213+
214+
actual = transform(to_pil_image(input))
215+
216+
assert_equal(expected, pil_to_tensor(actual))
217+
218+
def test_features_image(self, p):
219+
input, expected = self.input_expected_image_tensor(p)
220+
transform = transforms.RandomHorizontalFlip(p=p)
221+
222+
actual = transform(features.Image(input))
223+
224+
assert_equal(features.Image(expected), actual)
225+
226+
def test_features_segmentation_mask(self, p):
227+
input, expected = self.input_expected_image_tensor(p)
228+
transform = transforms.RandomHorizontalFlip(p=p)
229+
230+
actual = transform(features.SegmentationMask(input))
231+
232+
assert_equal(features.SegmentationMask(expected), actual)
233+
234+
def test_features_bounding_box(self, p):
235+
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
236+
transform = transforms.RandomHorizontalFlip(p=p)
237+
238+
actual = transform(input)
239+
240+
expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
241+
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
242+
assert_equal(expected, actual)
243+
assert actual.format == expected.format
244+
assert actual.image_size == expected.image_size

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
99
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
1010
from ._geometry import (
11-
HorizontalFlip,
1211
Resize,
1312
CenterCrop,
1413
RandomResizedCrop,
1514
FiveCrop,
1615
TenCrop,
1716
BatchMultiCrop,
17+
RandomHorizontalFlip,
1818
RandomZoomOut,
1919
)
2020
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace

torchvision/prototype/transforms/_geometry.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,25 @@
1313
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
1414

1515

16-
class HorizontalFlip(Transform):
16+
class RandomHorizontalFlip(Transform):
17+
def __init__(self, p: float = 0.5) -> None:
18+
super().__init__()
19+
self.p = p
20+
21+
def forward(self, *inputs: Any) -> Any:
22+
sample = inputs if len(inputs) > 1 else inputs[0]
23+
if torch.rand(1) >= self.p:
24+
return sample
25+
26+
return super().forward(sample)
27+
1728
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
1829
if isinstance(input, features.Image):
1930
output = F.horizontal_flip_image_tensor(input)
2031
return features.Image.new_like(input, output)
32+
elif isinstance(input, features.SegmentationMask):
33+
output = F.horizontal_flip_segmentation_mask(input)
34+
return features.SegmentationMask.new_like(input, output)
2135
elif isinstance(input, features.BoundingBox):
2236
output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
2337
return features.BoundingBox.new_like(input, output)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
horizontal_flip_bounding_box,
4141
horizontal_flip_image_tensor,
4242
horizontal_flip_image_pil,
43+
horizontal_flip_segmentation_mask,
4344
resize_bounding_box,
4445
resize_image_tensor,
4546
resize_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
horizontal_flip_image_pil = _FP.hflip
1616

1717

18+
def horizontal_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
19+
return horizontal_flip_image_tensor(segmentation_mask)
20+
21+
1822
def horizontal_flip_bounding_box(
1923
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
2024
) -> torch.Tensor:

0 commit comments

Comments
 (0)