@@ -200,6 +200,30 @@ def horizontal_flip_bounding_box():
         yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_image_tensor():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
     for image, interpolation, max_size, antialias in itertools.product(
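For orientation: the flip kernels these sample inputs exercise are thin wrappers over axis reversal, and the fixed-input test added further down pins horizontal flip to the last (width) dimension. A minimal sketch of the mask kernels, assuming they delegate to Tensor.flip (the real implementations live in torchvision's prototype functional namespace):

import torch

def horizontal_flip_segmentation_mask(mask: torch.Tensor) -> torch.Tensor:
    # Reverse the width axis; a value in column 0 ends up in column -1,
    # which is exactly what the fixed-input test below asserts.
    return mask.flip(-1)

def vertical_flip_segmentation_mask(mask: torch.Tensor) -> torch.Tensor:
    # Reverse the height axis.
    return mask.flip(-2)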
@@ -404,9 +428,17 @@ def crop_segmentation_mask():
 
 
 @register_kernel_info_from_sample_inputs_fn
-def vertical_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
-        yield SampleInput(mask)
+def resized_crop_image_tensor():
+    for image, top, left, height, width, size, antialias in itertools.product(
+        make_images(),
+        [-8, 9],
+        [-8, 9],
+        [12],
+        [12],
+        [(16, 18)],
+        [True, False],
+    ):
+        yield SampleInput(image, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
 
 
 @register_kernel_info_from_sample_inputs_fn
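resized_crop composes a crop with a resize; the negative top/left values in the sample inputs deliberately push the crop window past the image border. A sketch of the semantics in terms of the stable torchvision API (the exact padding behaviour of the prototype kernel is an assumption here):

from torchvision.transforms import functional as F_stable

def resized_crop_sketch(image, top, left, height, width, size, antialias=None):
    # Crop first (stable F.crop zero-pads wherever the window leaves the
    # image, which covers the top/left = -8 cases), then resize.
    cropped = F_stable.crop(image, top, left, height, width)
    return F_stable.resize(cropped, list(size), antialias=antialias)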
@@ -457,6 +489,19 @@ def pad_bounding_box():
         yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def perspective_image_tensor():
+    for image, perspective_coeffs, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+        [None, [128], [12.0]],  # fill
+    ):
+        yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def perspective_bounding_box():
     for bounding_box, perspective_coeffs in itertools.product(
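The eight perspective_coeffs follow the PIL convention: (a, b, c, d, e, f, g, h) map an output point (x, y) back to the input point ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)). A quick sanity check of the first coefficient set, using an illustrative helper that is not part of the test file:

def project(coeffs, x, y):
    a, b, c, d, e, f, g, h = coeffs
    denom = g * x + h * y + 1
    return (a * x + b * y + c) / denom, (d * x + e * y + f) / denom

# The origin moves purely by the translation terms c and f:
print(project([1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], 0, 0))
# (-6.9113, -5.235)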
@@ -488,6 +533,15 @@ def perspective_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_image_tensor():
+    for image, output_size in itertools.product(
+        make_images(sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size
+    ):
+        yield SampleInput(image, output_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
     for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
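The three output_size values cover a crop smaller than every image, one larger than every image (forcing the kernel to pad), and the single-int shorthand for a square crop. The anchor arithmetic, sketched under the assumption that it mirrors stable center_crop's integer halving:

def center_crop_anchor(image_h, image_w, crop_h, crop_w):
    # A negative offset means the crop window sticks out of the image,
    # so the kernel has to pad before cropping.
    top = (image_h - crop_h) // 2
    left = (image_w - crop_w) // 2
    return top, left

print(center_crop_anchor(16, 16, 42, 70))  # (-13, -27): padding required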
@@ -1181,6 +1235,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
     torch.testing.assert_close(output_mask, expected_mask)
 
 
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
+    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    mask[:, :, 0] = 1
+
+    out_mask = F.horizontal_flip_segmentation_mask(mask)
+
+    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    expected_mask[:, :, -1] = 1
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
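The fixed-input test pins the orientation: a mask with its first column set must come back with its last column set. Assuming the diff targets torchvision's test/test_prototype_transforms_functional.py (inferred from the helpers it uses, not stated in the hunks), the new test can be run in isolation with pytest -k test_correctness_horizontal_flip_segmentation_mask_on_fixed_input.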