Skip to content

Use torch.testing.assert_close in test_functional_tensor #3876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
May 24, 2021
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c4fc01b
adopt `torch.testing.assert_close` in test suite
pmeier May 20, 2021
bfbe19b
revert some changes
pmeier May 20, 2021
09f86f4
add todo
pmeier May 20, 2021
86402f0
flake8
pmeier May 20, 2021
48d32e6
Hopefully fixed test_functional_tensor
NicolasHug May 20, 2021
15b50e3
hopefully fixed test_ops
NicolasHug May 20, 2021
a54880f
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 20, 2021
61874ac
Fix test_utils
NicolasHug May 20, 2021
30f20a3
revert unwanted changes to test_image
NicolasHug May 20, 2021
3a29ae3
maybe fixed test_transforms
NicolasHug May 20, 2021
d6d73d0
Merge branch 'master' into assert-close
NicolasHug May 20, 2021
863f144
fix test_datasets_video_utils
pmeier May 21, 2021
c8a5afa
fix test_transforms
pmeier May 21, 2021
e697e88
Merge branch 'master' into assert-close
pmeier May 21, 2021
93614f0
flake8
pmeier May 21, 2021
11caf01
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 21, 2021
d7fde8c
Merge branch 'assert-close' of github.com:pmeier/vision into assert-c…
NicolasHug May 21, 2021
0b237c7
use cu102 see if the nightlies are actual nightlies?
NicolasHug May 21, 2021
c2ace86
obviously forgot to call regenerate.py
NicolasHug May 21, 2021
d78226a
not as obvious, reverting
NicolasHug May 21, 2021
bb543a7
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
7507a0c
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
4d2fbfb
revert everything but functional_tensor
NicolasHug May 21, 2021
7f0f769
Merge branch 'master' into assert_close_func_tensor
NicolasHug May 21, 2021
81f6604
Merge branch 'master' into assert_close_func_tensor
NicolasHug May 21, 2021
dcbe8b4
Merge branch 'master' into assert_close_func_tensor
NicolasHug May 22, 2021
3c16b26
Merge branch 'master' of github.com:pytorch/vision into assert_close_…
NicolasHug May 24, 2021
f009602
Convert PIL images to arrays so we can rely on assert_equal
NicolasHug May 24, 2021
54a124a
Merge branch 'assert_close_func_tensor' of github.com:NicolasHug/visi…
NicolasHug May 24, 2021
df4c366
Merge branch 'master' into assert_close_func_tensor
NicolasHug May 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester, cpu_and_gpu, needs_cuda
from _assert_utils import assert_equal

from typing import Dict, List, Sequence, Tuple

Expand All @@ -39,13 +40,13 @@ def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwarg
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
assert_equal(transformed_img, transformed_batch[i, ...])

if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ref the rtol=1e-5 comes from the current default of https://pytorch.org/docs/stable/generated/torch.allclose.html

I'd prefer to leave rtol to its default in assert_close if possible, but rtol must be set if atol is set. Would you know the reason @pmeier ? np.testing.assert_allclose doesn't have this constraint it seems

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We went for this logic, since we have non-zero defaults for rtol and atol. Imagine setting atol=0 and the tensors still match because rtol > 0. See https://github.com/pytorch/pytorch/blob/74c12da4517c789bea737dc947d6adc755f63176/torch/testing/_asserts.py#L391-L396.


def test_assert_image_tensor(self):
shape = (100,)
Expand Down Expand Up @@ -79,7 +80,7 @@ def test_vflip(self):

# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
self.assertTrue(vflipped_img.equal(vflipped_img_script))
assert_equal(vflipped_img, vflipped_img_script)

batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.vflip)
Expand All @@ -94,7 +95,7 @@ def test_hflip(self):

# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
self.assertTrue(hflipped_img.equal(hflipped_img_script))
assert_equal(hflipped_img, hflipped_img_script)

batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.hflip)
Expand Down Expand Up @@ -140,11 +141,10 @@ def test_hsv2rgb(self):
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
max_diff = (ft_img - colorsys_img).abs().max()
self.assertLess(max_diff, 1e-5)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

s_rgb_img = scripted_fn(hsv_img)
self.assertTrue(rgb_img.allclose(s_rgb_img))
torch.testing.assert_close(rgb_img, s_rgb_img)

batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_rgb2hsv(self):
self.assertLess(max_diff, 1e-5)

s_hsv_img = scripted_fn(rgb_img)
self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7))
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
Expand All @@ -194,7 +194,7 @@ def test_rgb_to_grayscale(self):
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))
assert_equal(s_gray_tensor, gray_tensor)

batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
Expand Down Expand Up @@ -240,12 +240,12 @@ def test_five_crop(self):
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
assert_equal(true_transformed_img, transformed_img)

# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)

def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
Expand All @@ -272,12 +272,12 @@ def test_ten_crop(self):
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
assert_equal(true_transformed_img, transformed_img)

# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)

def test_pad(self):
script_fn = torch.jit.script(F.pad)
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_pad(self):
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))

self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

Expand Down Expand Up @@ -348,9 +348,10 @@ def test_resize(self):
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation)
assert_equal(
resized_tensor.size()[1:],
resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation),
)

if interpolation not in [NEAREST, ]:
Expand All @@ -374,7 +375,7 @@ def test_resize(self):

resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
max_size=max_size)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation))

self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
Expand All @@ -384,7 +385,7 @@ def test_resize(self):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)

for img in (tensor, pil_img):
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
Expand All @@ -400,15 +401,17 @@ def test_resized_crop(self):

for mode in [NEAREST, BILINEAR, BICUBIC]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

# 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36, device=self.device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
expected_out_tensor = tensor[:, :20:2, :30:2]
self.assertTrue(
expected_out_tensor.equal(out_tensor),
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
assert_equal(
expected_out_tensor,
out_tensor,
check_stride=False,
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
)

batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
Expand All @@ -420,15 +423,11 @@ def _test_affine_identity_map(self, tensor, scripted_affine):
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
out_tensor = scripted_affine(
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
# 2) Test rotation
Expand All @@ -452,9 +451,11 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
assert_equal(
true_tensor,
out_tensor,
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]),
check_stride=False,
)

if out_tensor.dtype != torch.uint8:
Expand Down Expand Up @@ -593,18 +594,19 @@ def test_affine(self):
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)

with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
self.assertEqual(res1, res2)
# we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
assert_equal(np.asarray(res1), np.asarray(res2))

def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
Expand Down Expand Up @@ -682,13 +684,13 @@ def test_rotate(self):
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)

def test_gaussian_blur(self):
small_image_tensor = torch.from_numpy(
Expand Down Expand Up @@ -747,10 +749,8 @@ def test_gaussian_blur(self):

for fn in [F.gaussian_blur, scripted_transform]:
out = fn(tensor, kernel_size=ksize, sigma=sigma)
self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
self.assertLessEqual(
torch.max(true_out.float() - out.float()),
1.0,
torch.testing.assert_close(
out, true_out, rtol=0.0, atol=1.0, check_stride=False,
Copy link
Member Author

@NicolasHug NicolasHug May 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're comparing the max abs difference here instead of the max difference in master, but it's probably more correct with the proposed changes

msg="{}, {}".format(ksize, sigma)
)

Expand All @@ -771,7 +771,7 @@ def test_scale_channel(self):
img_chan = torch.randint(0, 256, size=size).to('cpu')
scaled_cpu = F_t._scale_channel(img_chan)
scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
assert_equal(scaled_cpu, scaled_cuda.to('cpu'))


def _get_data_dims_and_points_for_perspective():
Expand Down