Commit 6b900ab
Added more functional tests
1 parent 3a9aca1

1 file changed: test/test_prototype_transforms_functional.py (+73, -3)
@@ -200,6 +200,30 @@ def horizontal_flip_bounding_box():
         yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_image_tensor():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
     for image, interpolation, max_size, antialias in itertools.product(
@@ -404,9 +428,17 @@ def crop_segmentation_mask():
 
 
 @register_kernel_info_from_sample_inputs_fn
-def vertical_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
-        yield SampleInput(mask)
+def resized_crop_image_tensor():
+    for mask, top, left, height, width, size, antialias in itertools.product(
+        make_images(),
+        [-8, 9],
+        [-8, 9],
+        [12],
+        [12],
+        [(16, 18)],
+        [True, False],
+    ):
+        yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
 
 
 @register_kernel_info_from_sample_inputs_fn
@@ -457,6 +489,23 @@ def pad_bounding_box():
         yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def perspective_image_tensor():
+    for image, perspective_coeffs, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+        [None, [128], [12.0]],  # fill
+    ):
+        yield SampleInput(
+            image,
+            perspective_coeffs=perspective_coeffs,
+            fill=fill,
+        )
+
+
 @register_kernel_info_from_sample_inputs_fn
 def perspective_bounding_box():
     for bounding_box, perspective_coeffs in itertools.product(
@@ -488,6 +537,15 @@ def perspective_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_image_tensor():
+    for mask, output_size in itertools.product(
+        make_images(sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size
+    ):
+        yield SampleInput(mask, output_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
     for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
@@ -1181,6 +1239,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
     torch.testing.assert_close(output_mask, expected_mask)
 
 
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
+    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    mask[:, :, 0] = 1
+
+    out_mask = F.horizontal_flip_segmentation_mask(mask)
+
+    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    expected_mask[:, :, -1] = 1
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
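For readers unfamiliar with the pattern these tests follow: each decorated generator yields SampleInput objects, and the decorator registers the generator under the kernel name it mirrors (e.g. vertical_flip_image_tensor for F.vertical_flip_image_tensor) so a single parametrized test can exercise every kernel against its sample inputs. Below is a minimal sketch of how such a registry could be wired up; SampleInput here is a hypothetical stand-in, and the decorator's actual implementation in this test file may differ:

# Minimal sketch of the registration pattern, with hypothetical stand-ins;
# the real SampleInput and decorator are defined elsewhere in the test file.
from typing import Any, Callable, Dict, Iterator


class SampleInput:
    """Hypothetical container bundling positional and keyword kernel arguments."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs


# Registry mapping a kernel name to the generator that yields its sample inputs.
KERNEL_SAMPLE_INPUTS: Dict[str, Callable[[], Iterator[SampleInput]]] = {}


def register_kernel_info_from_sample_inputs_fn(fn: Callable[[], Iterator[SampleInput]]):
    # By convention the generator's name mirrors the kernel under test,
    # e.g. vertical_flip_image_tensor -> F.vertical_flip_image_tensor.
    KERNEL_SAMPLE_INPUTS[fn.__name__] = fn
    return fn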
