|
2 | 2 |
|
3 | 3 | import pytest
|
4 | 4 | import torch
|
| 5 | +from common_utils import assert_equal |
5 | 6 | from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
|
6 | 7 | from torchvision.prototype import transforms, features
|
7 |
| -from torchvision.transforms.functional import to_pil_image |
| 8 | +from torchvision.transforms.functional import to_pil_image, pil_to_tensor |
8 | 9 |
|
9 | 10 |
|
10 | 11 | def make_vanilla_tensor_images(*args, **kwargs):
|
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms):
|
66 | 67 | class TestSmoke:
|
67 | 68 | @parametrize_from_transforms(
|
68 | 69 | transforms.RandomErasing(p=1.0),
|
69 |
| - transforms.HorizontalFlip(), |
70 | 70 | transforms.Resize([16, 16]),
|
71 | 71 | transforms.CenterCrop([16, 16]),
|
72 | 72 | transforms.ConvertImageDtype(),
|
| 73 | + transforms.RandomHorizontalFlip(), |
73 | 74 | )
|
74 | 75 | def test_common(self, transform, input):
|
75 | 76 | transform(input)
|
@@ -188,3 +189,56 @@ def test_random_resized_crop(self, transform, input):
|
188 | 189 | )
|
189 | 190 | def test_convert_image_color_space(self, transform, input):
|
190 | 191 | transform(input)
|
| 192 | + |
| 193 | + |
@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
    def input_expected_image_tensor(self, p, dtype=torch.float32):
        """Return a 2x2x2 test image and the output expected from a flip with probability *p*."""
        image = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
        # p == 1 means the flip always fires; any other parametrized p (here 0.0)
        # means the input passes through unchanged.
        if p == 1:
            flipped = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
            return image, flipped
        return image, image

    def test_simple_tensor(self, p):
        """A plain tensor image is flipped (or left alone) according to p."""
        image, expected = self.input_expected_image_tensor(p)

        actual = transforms.RandomHorizontalFlip(p=p)(image)

        assert_equal(expected, actual)

    def test_pil_image(self, p):
        """A PIL image goes through the transform; compare after converting back to a tensor."""
        image, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)

        actual = transforms.RandomHorizontalFlip(p=p)(to_pil_image(image))

        assert_equal(expected, pil_to_tensor(actual))

    def test_features_image(self, p):
        """A features.Image wrapper is flipped like a plain tensor."""
        image, expected = self.input_expected_image_tensor(p)

        actual = transforms.RandomHorizontalFlip(p=p)(features.Image(image))

        assert_equal(features.Image(expected), actual)

    def test_features_segmentation_mask(self, p):
        """A features.SegmentationMask is flipped like an image."""
        mask, expected = self.input_expected_image_tensor(p)

        actual = transforms.RandomHorizontalFlip(p=p)(features.SegmentationMask(mask))

        assert_equal(features.SegmentationMask(expected), actual)

    def test_features_bounding_box(self, p):
        """An XYXY bounding box gets its x-coordinates mirrored across the image width."""
        box = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))

        actual = transforms.RandomHorizontalFlip(p=p)(box)

        # Mirroring [0, 0, 5, 5] in a width-10 image yields [5, 0, 10, 5].
        data = torch.tensor([5, 0, 10, 5]) if p == 1.0 else box
        expected = features.BoundingBox.new_like(box, data=data)
        assert_equal(expected, actual)
        # The metadata carried by the feature must survive the transform as well.
        assert actual.format == expected.format
        assert actual.image_size == expected.image_size
0 commit comments