@@ -59,7 +59,7 @@ def make_images(
59
59
yield make_image (size , color_space = color_space , dtype = dtype )
60
60
61
61
for color_space , dtype , extra_dims_ in itertools .product (color_spaces , dtypes , extra_dims ):
62
- yield make_image (color_space = color_space , extra_dims = extra_dims_ , dtype = dtype )
62
+ yield make_image (size = sizes [ 0 ], color_space = color_space , extra_dims = extra_dims_ , dtype = dtype )
63
63
64
64
65
65
def randint_with_tensor_bounds (arg1 , arg2 = None , ** kwargs ):
@@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype
149
149
150
150
151
151
def make_segmentation_masks (
152
- image_sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 )),
152
+ sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 )),
153
153
dtypes = (torch .long ,),
154
154
extra_dims = ((), (4 ,), (2 , 3 )),
155
155
):
156
- for image_size , dtype , extra_dims_ in itertools .product (image_sizes , dtypes , extra_dims ):
157
- yield make_segmentation_mask (size = image_size , dtype = dtype , extra_dims = extra_dims_ )
156
+ for size , dtype , extra_dims_ in itertools .product (sizes , dtypes , extra_dims ):
157
+ yield make_segmentation_mask (size = size , dtype = dtype , extra_dims = extra_dims_ )
158
158
159
159
160
160
class SampleInput :
@@ -587,7 +587,7 @@ def center_crop_bounding_box():
587
587
@register_kernel_info_from_sample_inputs_fn
588
588
def center_crop_segmentation_mask ():
589
589
for mask , output_size in itertools .product (
590
- make_segmentation_masks (image_sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 ))),
590
+ make_segmentation_masks (sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 ))),
591
591
[[4 , 3 ], [42 , 70 ], [4 ]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
592
592
):
593
593
yield SampleInput (mask , output_size )
@@ -1785,5 +1785,50 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
1785
1785
torch .tensor (true_cv2_results [gt_key ]).reshape (shape [- 2 ], shape [- 1 ], shape [- 3 ]).permute (2 , 0 , 1 ).to (tensor )
1786
1786
)
1787
1787
1788
- out = fn (tensor , kernel_size = ksize , sigma = sigma )
1788
+ image = features .Image (tensor )
1789
+
1790
+ out = fn (image , kernel_size = ksize , sigma = sigma )
1789
1791
torch .testing .assert_close (out , true_out , rtol = 0.0 , atol = 1.0 , msg = f"{ ksize } , { sigma } " )
1792
+
1793
+
1794
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1795
+ @pytest .mark .parametrize (
1796
+ "fn, make_samples" , [(F .elastic_image_tensor , make_images ), (F .elastic_segmentation_mask , make_segmentation_masks )]
1797
+ )
1798
+ def test_correctness_elastic_image_or_mask_tensor (device , fn , make_samples ):
1799
+ in_box = [10 , 15 , 25 , 35 ]
1800
+ for sample in make_samples (sizes = ((64 , 76 ),), extra_dims = ((), (4 ,))):
1801
+ c , h , w = sample .shape [- 3 :]
1802
+ # Setup a dummy image with 4 points
1803
+ sample [..., in_box [1 ], in_box [0 ]] = torch .tensor ([12 , 34 , 96 , 112 ])[:c ]
1804
+ sample [..., in_box [3 ] - 1 , in_box [0 ]] = torch .tensor ([12 , 34 , 96 , 112 ])[:c ]
1805
+ sample [..., in_box [3 ] - 1 , in_box [2 ] - 1 ] = torch .tensor ([12 , 34 , 96 , 112 ])[:c ]
1806
+ sample [..., in_box [1 ], in_box [2 ] - 1 ] = torch .tensor ([12 , 34 , 96 , 112 ])[:c ]
1807
+ sample = sample .to (device )
1808
+
1809
+ if fn == F .elastic_image_tensor :
1810
+ sample = features .Image (sample )
1811
+ kwargs = {"interpolation" : F .InterpolationMode .NEAREST }
1812
+ else :
1813
+ sample = features .SegmentationMask (sample )
1814
+ kwargs = {}
1815
+
1816
+ # Create a displacement grid using sin
1817
+ n , m = 5.0 , 0.1
1818
+ d1 = m * torch .sin (torch .arange (h , dtype = torch .float ) * torch .pi * n / h )
1819
+ d2 = m * torch .sin (torch .arange (w , dtype = torch .float ) * torch .pi * n / w )
1820
+
1821
+ d1 = d1 [:, None ].expand ((h , w ))
1822
+ d2 = d2 [None , :].expand ((h , w ))
1823
+
1824
+ displacement = torch .cat ([d1 [..., None ], d2 [..., None ]], dim = - 1 )
1825
+ displacement = displacement .reshape (1 , h , w , 2 )
1826
+
1827
+ print (sample .dtype , sample .shape )
1828
+ output = fn (sample , displacement = displacement , ** kwargs )
1829
+
1830
+ # Check places where transformed points should be
1831
+ torch .testing .assert_close (output [..., 12 , 9 ], sample [..., in_box [1 ], in_box [0 ]])
1832
+ torch .testing .assert_close (output [..., 17 , 27 ], sample [..., in_box [1 ], in_box [2 ] - 1 ])
1833
+ torch .testing .assert_close (output [..., 31 , 6 ], sample [..., in_box [3 ] - 1 , in_box [0 ]])
1834
+ torch .testing .assert_close (output [..., 37 , 23 ], sample [..., in_box [3 ] - 1 , in_box [2 ] - 1 ])
0 commit comments