Commit d5276bf

NicolasHug authored and facebook-github-bot committed
[fbsync] port vertical flip (#7712)
Reviewed By: vmoens

Differential Revision: D47186566

fbshipit-source-id: 92dd32411629e98d4b82c69cd9a000bd92eeb5fb
1 parent 541dbe2 commit d5276bf
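
For readers unfamiliar with the op being ported: a minimal usage sketch, assuming the v2 API exercised in the tests below (illustrative only, not part of this commit).

import torch
import torchvision.transforms.v2 as transforms
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

# Wrap a small CHW tensor as an Image datapoint; with p=1 the transform always flips.
image = datapoints.Image(torch.arange(3 * 4 * 4, dtype=torch.uint8).reshape(3, 4, 4))
flipped = transforms.RandomVerticalFlip(p=1)(image)

# The functional dispatcher gives the same result: rows mirrored along the height axis.
assert torch.equal(flipped, F.vertical_flip(image))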

File tree

5 files changed: +146 additions, -151 deletions

test/test_transforms_v2.py
test/test_transforms_v2_refactored.py
test/transforms_v2_dispatcher_infos.py
test/transforms_v2_kernel_infos.py
torchvision/transforms/v2/functional/_geometry.py


test/test_transforms_v2.py

Lines changed: 1 addition & 54 deletions
@@ -29,7 +29,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.transforms.functional import InterpolationMode, to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
 
@@ -406,59 +406,6 @@ def was_applied(output, inpt):
     assert transform.was_applied(output, input)
 
 
-@pytest.mark.parametrize("p", [0.0, 1.0])
-class TestRandomVerticalFlip:
-    def input_expected_image_tensor(self, p, dtype=torch.float32):
-        input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
-        expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)
-
-        return input, expected if p == 1 else input
-
-    def test_simple_tensor(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomVerticalFlip(p=p)
-
-        actual = transform(input)
-
-        assert_equal(expected, actual)
-
-    def test_pil_image(self, p):
-        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
-        transform = transforms.RandomVerticalFlip(p=p)
-
-        actual = transform(to_pil_image(input))
-
-        assert_equal(expected, pil_to_tensor(actual))
-
-    def test_datapoints_image(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomVerticalFlip(p=p)
-
-        actual = transform(datapoints.Image(input))
-
-        assert_equal(datapoints.Image(expected), actual)
-
-    def test_datapoints_mask(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomVerticalFlip(p=p)
-
-        actual = transform(datapoints.Mask(input))
-
-        assert_equal(datapoints.Mask(expected), actual)
-
-    def test_datapoints_bounding_box(self, p):
-        input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
-        transform = transforms.RandomVerticalFlip(p=p)
-
-        actual = transform(input)
-
-        expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
-        expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
-        assert_equal(expected, actual)
-        assert actual.format == expected.format
-        assert actual.spatial_size == expected.spatial_size
-
-
 class TestPad:
     def test_assertions(self):
         with pytest.raises(TypeError, match="Got inappropriate padding arg"):

test/test_transforms_v2_refactored.py

Lines changed: 143 additions & 5 deletions
@@ -842,7 +842,7 @@ def _reference_horizontal_flip_bounding_box(self, bounding_box):
         "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
     )
     def test_bounding_box_correctness(self, format, fn):
-        bounding_box = self._make_input(datapoints.BoundingBox)
+        bounding_box = self._make_input(datapoints.BoundingBox, format=format)
 
         actual = fn(bounding_box)
         expected = self._reference_horizontal_flip_bounding_box(bounding_box)
@@ -1025,12 +1025,10 @@ def test_kernel_bounding_box(self, param, value, format, dtype, device):
 
     @pytest.mark.parametrize("mask_type", ["segmentation", "detection"])
     def test_kernel_mask(self, mask_type):
-        check_kernel(
-            F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type), **self._MINIMAL_AFFINE_KWARGS
-        )
+        self._check_kernel(F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type))
 
     def test_kernel_video(self):
-        check_kernel(F.affine_video, self._make_input(datapoints.Video), **self._MINIMAL_AFFINE_KWARGS)
+        self._check_kernel(F.affine_video, self._make_input(datapoints.Video))
 
     @pytest.mark.parametrize(
         ("input_type", "kernel"),
@@ -1301,3 +1299,143 @@ def test_transform_negative_shear_error(self):
     def test_transform_unknown_fill_error(self):
         with pytest.raises(TypeError, match="Got inappropriate fill arg"):
             transforms.RandomAffine(degrees=0, fill="fill")
+
+
+class TestVerticalFlip:
+    def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
+        if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
+            input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+            if input_type is torch.Tensor:
+                input = input.as_subclass(torch.Tensor)
+            elif input_type is PIL.Image.Image:
+                input = F.to_image_pil(input)
+        elif input_type is datapoints.BoundingBox:
+            kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
+            input = make_bounding_box(
+                dtype=dtype or torch.float32,
+                device=device,
+                spatial_size=spatial_size,
+                **kwargs,
+            )
+        elif input_type is datapoints.Mask:
+            input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+        elif input_type is datapoints.Video:
+            input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+
+        return input
+
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_tensor(self, dtype, device):
+        check_kernel(F.vertical_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_bounding_box(self, format, dtype, device):
+        bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
+        check_kernel(
+            F.vertical_flip_bounding_box,
+            bounding_box,
+            format=format,
+            spatial_size=bounding_box.spatial_size,
+        )
+
+    @pytest.mark.parametrize(
+        "dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
+    )
+    def test_kernel_mask(self, dtype_and_make_mask):
+        dtype, make_mask = dtype_and_make_mask
+        check_kernel(F.vertical_flip_mask, make_mask(dtype=dtype))
+
+    def test_kernel_video(self):
+        check_kernel(F.vertical_flip_video, self._make_input(datapoints.Video))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.vertical_flip_image_tensor),
+            (PIL.Image.Image, F.vertical_flip_image_pil),
+            (datapoints.Image, F.vertical_flip_image_tensor),
+            (datapoints.BoundingBox, F.vertical_flip_bounding_box),
+            (datapoints.Mask, F.vertical_flip_mask),
+            (datapoints.Video, F.vertical_flip_video),
+        ],
+    )
+    def test_dispatcher(self, kernel, input_type):
+        check_dispatcher(F.vertical_flip, kernel, self._make_input(input_type))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.vertical_flip_image_tensor),
+            (PIL.Image.Image, F.vertical_flip_image_pil),
+            (datapoints.Image, F.vertical_flip_image_tensor),
+            (datapoints.BoundingBox, F.vertical_flip_bounding_box),
+            (datapoints.Mask, F.vertical_flip_mask),
+            (datapoints.Video, F.vertical_flip_video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        check_transform(transforms.RandomVerticalFlip, input, p=1)
+
+    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
+    def test_image_correctness(self, fn):
+        image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
+
+        actual = fn(image)
+        expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))
+
+        torch.testing.assert_close(actual, expected)
+
+    def _reference_vertical_flip_bounding_box(self, bounding_box):
+        affine_matrix = np.array(
+            [
+                [1, 0, 0],
+                [0, -1, bounding_box.spatial_size[0]],
+            ],
+            dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        )
+
+        expected_bboxes = reference_affine_bounding_box_helper(
+            bounding_box,
+            format=bounding_box.format,
+            spatial_size=bounding_box.spatial_size,
+            affine_matrix=affine_matrix,
+        )
+
+        return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
+    def test_bounding_box_correctness(self, format, fn):
+        bounding_box = self._make_input(datapoints.BoundingBox, format=format)
+
+        actual = fn(bounding_box)
+        expected = self._reference_vertical_flip_bounding_box(bounding_box)
+
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_noop(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        transform = transforms.RandomVerticalFlip(p=0)
+
+        output = transform(input)
+
+        assert_equal(output, input)
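
The _reference_vertical_flip_bounding_box helper above encodes the affine matrix [[1, 0, 0], [0, -1, H]]: every y coordinate maps to H - y while x is unchanged, so an XYXY box (x1, y1, x2, y2) becomes (x1, H - y2, x2, H - y1). A standalone sketch of that mapping, with box values that match the example from the removed legacy test (illustrative only):

import torch

def vflip_xyxy(box: torch.Tensor, height: int) -> torch.Tensor:
    # (x1, y1, x2, y2) -> (x1, H - y2, x2, H - y1); the two y bounds swap so y1 <= y2 still holds.
    x1, y1, x2, y2 = box.unbind(-1)
    return torch.stack([x1, height - y2, x2, height - y1], dim=-1)

box = torch.tensor([0.0, 0.0, 5.0, 5.0])  # box in a 10x10 image
print(vflip_xyxy(box, height=10))         # tensor([ 0.,  5.,  5., 10.])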

test/transforms_v2_dispatcher_infos.py

Lines changed: 0 additions & 10 deletions
@@ -138,16 +138,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
 
 
 DISPATCHER_INFOS = [
-    DispatcherInfo(
-        F.vertical_flip,
-        kernels={
-            datapoints.Image: F.vertical_flip_image_tensor,
-            datapoints.Video: F.vertical_flip_video,
-            datapoints.BoundingBox: F.vertical_flip_bounding_box,
-            datapoints.Mask: F.vertical_flip_mask,
-        },
-        pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
-    ),
     DispatcherInfo(
         F.rotate,
         kernels={

test/transforms_v2_kernel_infos.py

Lines changed: 0 additions & 81 deletions
@@ -264,87 +264,6 @@ def reference_inputs_convert_format_bounding_box():
     )
 
 
-def sample_inputs_vertical_flip_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
-        yield ArgsKwargs(image_loader)
-
-
-def reference_inputs_vertical_flip_image_tensor():
-    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_vertical_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
-    ):
-        yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
-        )
-
-
-def sample_inputs_vertical_flip_mask():
-    for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_vertical_flip_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
-        yield ArgsKwargs(video_loader)
-
-
-def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
-    affine_matrix = np.array(
-        [
-            [1, 0, 0],
-            [0, -1, spatial_size[0]],
-        ],
-        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
-    )
-
-    expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
-    )
-
-    return expected_bboxes
-
-
-def reference_inputs_vertical_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
-        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
-        )
-
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.vertical_flip_image_tensor,
-            kernel_name="vertical_flip_image_tensor",
-            sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
-            reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
-            reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
-            float32_vs_uint8=True,
-        ),
-        KernelInfo(
-            F.vertical_flip_bounding_box,
-            sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
-            reference_fn=reference_vertical_flip_bounding_box,
-            reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
-        ),
-        KernelInfo(
-            F.vertical_flip_mask,
-            sample_inputs_fn=sample_inputs_vertical_flip_mask,
-        ),
-        KernelInfo(
-            F.vertical_flip_video,
-            sample_inputs_fn=sample_inputs_vertical_flip_video,
-        ),
-    ]
-)
-
-
 _ROTATE_ANGLES = [-87, 15, 90]
 
 

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-2)
 
 
-vertical_flip_image_pil = _FP.vflip
+def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
+    return _FP.vflip(image)
 
 
 def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
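
As a quick illustration of the tensor kernel above (a sketch, not part of the commit): flip(-2) reverses the height dimension, so the top row becomes the bottom row, which is exactly what the removed legacy test checked with a 2x2 input.

import torch

image = torch.tensor([[[1, 1],
                       [0, 0]]], dtype=torch.uint8)  # 1 channel, 2x2

# vertical_flip_image_tensor reduces to image.flip(-2): row order is reversed.
print(image.flip(-2))
# tensor([[[0, 0],
#          [1, 1]]], dtype=torch.uint8)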
