
Commit 2099d00

Added tests
1 parent fa1a846 commit 2099d00

2 files changed: 53 additions, 7 deletions

test/test_prototype_transforms_functional.py

Lines changed: 51 additions & 6 deletions
@@ -59,7 +59,7 @@ def make_images(
         yield make_image(size, color_space=color_space, dtype=dtype)
 
     for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
-        yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
+        yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
 
 
 def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype
 
 
 def make_segmentation_masks(
-    image_sizes=((16, 16), (7, 33), (31, 9)),
+    sizes=((16, 16), (7, 33), (31, 9)),
     dtypes=(torch.long,),
     extra_dims=((), (4,), (2, 3)),
 ):
-    for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims):
-        yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_)
+    for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
+        yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
 
 
 class SampleInput:
@@ -587,7 +587,7 @@ def center_crop_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_segmentation_mask():
     for mask, output_size in itertools.product(
-        make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
+        make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
         [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop_sizes > image sizes, single crop size
     ):
         yield SampleInput(mask, output_size)
@@ -1785,5 +1785,50 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
         torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
     )
 
-    out = fn(tensor, kernel_size=ksize, sigma=sigma)
+    image = features.Image(tensor)
+
+    out = fn(image, kernel_size=ksize, sigma=sigma)
     torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
+)
+def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
+    in_box = [10, 15, 25, 35]
+    for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
+        c, h, w = sample.shape[-3:]
+        # Setup a dummy image with 4 points
+        sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
+        sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
+        sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
+        sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
+        sample = sample.to(device)
+
+        if fn == F.elastic_image_tensor:
+            sample = features.Image(sample)
+            kwargs = {"interpolation": F.InterpolationMode.NEAREST}
+        else:
+            sample = features.SegmentationMask(sample)
+            kwargs = {}
+
+        # Create a displacement grid using sin
+        n, m = 5.0, 0.1
+        d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h)
+        d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w)
+
+        d1 = d1[:, None].expand((h, w))
+        d2 = d2[None, :].expand((h, w))
+
+        displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1)
+        displacement = displacement.reshape(1, h, w, 2)
+
+        print(sample.dtype, sample.shape)
+        output = fn(sample, displacement=displacement, **kwargs)
+
+        # Check places where transformed points should be
+        torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]])
+        torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
+        torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
+        torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
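
Note: the hard-coded expected coordinates in the new test ((12, 9), (17, 27), (31, 6), (37, 23)) follow from the way the elastic kernels sample the input through grid = id_grid + displacement. The following is a minimal sketch, not part of the commit, reproducing the arithmetic for one marker; it assumes grid_sample's align_corners=False convention, i.e. an identity-grid coordinate of (2 * x + 1) / size - 1, matching _FT._create_identity_grid.

# Minimal sketch (assumption: align_corners=False normalization, as in _create_identity_grid).
import torch

h, w = 64, 76          # the (64, 76) sample size used by the test
n, m = 5.0, 0.1        # same sinusoidal displacement parameters as the test

# Displacement components at output pixel (row=12, col=9).
dx = m * torch.sin(torch.tensor(12.0) * torch.pi * n / h)  # d1: varies with the row index
dy = m * torch.sin(torch.tensor(9.0) * torch.pi * n / w)   # d2: varies with the column index

# Normalized identity-grid coordinates of that output pixel.
x_norm = (2 * 9 + 1) / w - 1
y_norm = (2 * 12 + 1) / h - 1

# grid = id_grid + displacement gives the input location this output pixel samples from.
x_src = ((x_norm + dx + 1) * w - 1) / 2
y_src = ((y_norm + dy + 1) * h - 1) / 2

print(round(y_src.item()), round(x_src.item()))  # -> 15 10, i.e. (in_box[1], in_box[0])

With NEAREST interpolation the sampling location rounds to the marked pixel at (in_box[1], in_box[0]), which is why that marker is expected to reappear at output[..., 12, 9]; the other three expected coordinates follow the same arithmetic.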

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 2 additions & 1 deletion
@@ -826,7 +826,8 @@ def elastic_bounding_box(
     image_size = displacement.shape[-3], displacement.shape[-2]
 
     id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device)
-    # We construct inverse grid vs grid = id_grid + displacement used for images
+    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
+    # This is not an exact inverse of the grid
     inv_grid = id_grid - displacement
 
     # Get points from bboxes
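
The amended comment flags that inv_grid = id_grid - displacement only approximates the inverse of grid = id_grid + displacement: an exact inverse at a point p would need a q with q + d(q) = p, whereas p - d(p) evaluates the displacement at p instead of q. The toy 1-D check below (illustrative only, reusing the sinusoidal shape from the test above) shows the residual this leaves.

import torch

# Toy 1-D displacement field with the same sinusoidal shape as the new test.
h, n, m = 64, 5.0, 0.1

def d(y):
    return m * torch.sin(y * torch.pi * n / h)

p = torch.tensor(12.0)
q_approx = p - d(p)                        # the approximate inverse used for bounding boxes
residual = (q_approx + d(q_approx)) - p    # an exact inverse would make this exactly 0
print(abs(residual).item())                # small but nonzero (about 5e-4 here)

The residual grows with how quickly the displacement field varies, which is what the second comment line makes explicit.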
