@@ -200,6 +200,30 @@ def horizontal_flip_bounding_box():
         yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)


+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_image_tensor():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def vertical_flip_segmentation_mask():
+    for mask in make_segmentation_masks():
+        yield SampleInput(mask)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
     for image, interpolation, max_size, antialias in itertools.product(
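
Aside: `SampleInput` and `register_kernel_info_from_sample_inputs_fn` are defined earlier in this test file and are not part of this diff. As a rough mental model only, here is a minimal sketch of the registry pattern these decorators follow (names and fields are illustrative, not the real definitions):

class SampleInput:
    # Bundles positional and keyword arguments for one kernel invocation.
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

KERNEL_SAMPLE_INPUT_FNS = []

def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
    # The generator's name (e.g. "vertical_flip_image_tensor") identifies the
    # kernel under test; a parametrized test later feeds each SampleInput to it.
    KERNEL_SAMPLE_INPUT_FNS.append(sample_inputs_fn)
    return sample_inputs_fn

Each registration above therefore only has to enumerate representative inputs; the shared dispatch and assertion logic lives in one place.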
@@ -404,9 +428,17 @@ def crop_segmentation_mask():


 @register_kernel_info_from_sample_inputs_fn
-def vertical_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
-        yield SampleInput(mask)
+def resized_crop_image_tensor():
+    for image, top, left, height, width, size, antialias in itertools.product(
+        make_images(),
+        [-8, 9],  # top
+        [-8, 9],  # left
+        [12],  # height
+        [12],  # width
+        [(16, 18)],  # size
+        [True, False],  # antialias
+    ):
+        yield SampleInput(image, top=top, left=left, height=height, width=width, size=size, antialias=antialias)


 @register_kernel_info_from_sample_inputs_fn
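
For context on what these new sample inputs exercise: `resized_crop` composes a crop, here with deliberately out-of-bounds `top`/`left` values like -8, with a resize to `size`. A minimal reference sketch, assuming zero fill for the out-of-bounds region and bilinear resampling (the kernel under test, not this sketch, defines the actual semantics):

import torch
from torch.nn.functional import interpolate

def resized_crop_reference(image, top, left, height, width, size):
    # Crop a (C, H, W) tensor, zero-filling whatever falls outside the image,
    # then resize the crop to `size`.
    c, h, w = image.shape
    crop = torch.zeros((c, height, width), dtype=torch.float32)
    y0, y1 = max(top, 0), min(top + height, h)
    x0, x1 = max(left, 0), min(left + width, w)
    if y0 < y1 and x0 < x1:
        crop[:, y0 - top : y1 - top, x0 - left : x1 - left] = image[:, y0:y1, x0:x1].float()
    return interpolate(crop.unsqueeze(0), size=size, mode="bilinear", align_corners=False).squeeze(0)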
@@ -457,6 +489,23 @@ def pad_bounding_box():
         yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)


+@register_kernel_info_from_sample_inputs_fn
+def perspective_image_tensor():
+    for image, perspective_coeffs, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+        [None, [128], [12.0]],  # fill
+    ):
+        yield SampleInput(
+            image,
+            perspective_coeffs=perspective_coeffs,
+            fill=fill,
+        )
+
+
 @register_kernel_info_from_sample_inputs_fn
 def perspective_bounding_box():
     for bounding_box, perspective_coeffs in itertools.product(
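
The eight `perspective_coeffs` values (a, b, c, d, e, f, g, h) parametrize a projective map of the plane; in the usual convention a point (x, y) goes to ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)). A small sketch of that convention (whether the kernel applies it as the forward map or as the inverse sampling map is defined by the implementation, not here):

def apply_perspective_coeffs(coeffs, x, y):
    # coeffs = (a, b, c, d, e, f, g, h), laid out as in the lists above.
    a, b, c, d, e, f, g, h = coeffs
    denom = g * x + h * y + 1.0
    return (a * x + b * y + c) / denom, (d * x + e * y + f) / denom

# The first sample above maps the origin to (c, f) = (-6.9113, -5.235):
print(apply_perspective_coeffs([1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], 0.0, 0.0))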
@@ -488,6 +537,15 @@ def perspective_segmentation_mask():
         )


+@register_kernel_info_from_sample_inputs_fn
+def center_crop_image_tensor():
+    for image, output_size in itertools.product(
+        make_images(sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size
+    ):
+        yield SampleInput(image, output_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
     for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
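
As the inline comment in `center_crop_image_tensor` notes, the sample sizes cover crops smaller than, larger than, and square relative to the image. A minimal sketch of the usual center-crop arithmetic, assuming floor rounding (a negative offset means the crop window extends past the image and has to be padded):

def center_crop_origin(image_size, output_size):
    # A single-element output_size such as [4] is expanded to a square crop.
    if len(output_size) == 1:
        output_size = [output_size[0], output_size[0]]
    (h, w), (out_h, out_w) = image_size, output_size
    top = (h - out_h) // 2   # e.g. -13 for a [42, 70] crop of a (16, 16) image
    left = (w - out_w) // 2
    return top, left

print(center_crop_origin((16, 16), [4, 3]))  # (6, 6)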
@@ -1181,6 +1239,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
     torch.testing.assert_close(output_mask, expected_mask)


+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
+    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    mask[:, :, 0] = 1
+
+    out_mask = F.horizontal_flip_segmentation_mask(mask)
+
+    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
+    expected_mask[:, :, -1] = 1
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
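
These fixed-input tests pin down the flip semantics directly: ones in the first column must come back as ones in the last column. Assuming the kernels agree with `torch.flip` on the width axis (an assumption of this sketch, not something the diff states), the same expectation reads:

import torch

mask = torch.zeros((3, 3, 3), dtype=torch.long)
mask[:, :, 0] = 1

flipped = torch.flip(mask, dims=(-1,))  # reverse the last (width) dimension

expected = torch.zeros((3, 3, 3), dtype=torch.long)
expected[:, :, -1] = 1
torch.testing.assert_close(flipped, expected)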