diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index c880e8db55b..e369dad6271 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -487,6 +487,15 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def center_crop_image_tensor(): + for image, output_size in itertools.product( + make_images(sizes=((16, 16), (7, 33), (31, 9))), + [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size + ): + yield SampleInput(image, output_size) + + @register_kernel_info_from_sample_inputs_fn def center_crop_bounding_box(): for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]): @@ -495,6 +504,7 @@ def center_crop_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn def center_crop_segmentation_mask(): for mask, output_size in itertools.product( make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),