-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Use torch.testing.assert_close in test_functional_tensor #3876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
NicolasHug
merged 30 commits into
pytorch:master
from
NicolasHug:assert_close_func_tensor
May 24, 2021
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
c4fc01b
adopt `torch.testing.assert_close` in test suite
pmeier bfbe19b
revert some changes
pmeier 09f86f4
add todo
pmeier 86402f0
flake8
pmeier 48d32e6
Hopefully fixed test_functional_tensor
NicolasHug 15b50e3
hopefully fixed test_ops
NicolasHug a54880f
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug 61874ac
Fix test_utils
NicolasHug 30f20a3
revert unwanted changes to test_image
NicolasHug 3a29ae3
maybe fixed test_transforms
NicolasHug d6d73d0
Merge branch 'master' into assert-close
NicolasHug 863f144
fix test_datasets_video_utils
pmeier c8a5afa
fix test_transforms
pmeier e697e88
Merge branch 'master' into assert-close
pmeier 93614f0
flake8
pmeier 11caf01
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug d7fde8c
Merge branch 'assert-close' of github.com:pmeier/vision into assert-c…
NicolasHug 0b237c7
use cu102 see if the nightlies are actual nightlies?
NicolasHug c2ace86
obviously forgot to call regenerate.py
NicolasHug d78226a
not as obvious, reverting
NicolasHug bb543a7
Merge branch 'master' into assert-close
NicolasHug 7507a0c
Merge branch 'master' into assert-close
NicolasHug 4d2fbfb
revert everything but functional_tensor
NicolasHug 7f0f769
Merge branch 'master' into assert_close_func_tensor
NicolasHug 81f6604
Merge branch 'master' into assert_close_func_tensor
NicolasHug dcbe8b4
Merge branch 'master' into assert_close_func_tensor
NicolasHug 3c16b26
Merge branch 'master' of github.com:pytorch/vision into assert_close_…
NicolasHug f009602
Convert PIL images to arrays so we can rely on assert_equal
NicolasHug 54a124a
Merge branch 'assert_close_func_tensor' of github.com:NicolasHug/visi…
NicolasHug df4c366
Merge branch 'master' into assert_close_func_tensor
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from torchvision.transforms import InterpolationMode | ||
|
||
from common_utils import TransformsTester, cpu_and_gpu, needs_cuda | ||
from _assert_utils import assert_equal | ||
|
||
from typing import Dict, List, Sequence, Tuple | ||
|
||
|
@@ -39,13 +40,13 @@ def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwarg | |
for i in range(len(batch_tensors)): | ||
img_tensor = batch_tensors[i, ...] | ||
transformed_img = fn(img_tensor, **fn_kwargs) | ||
self.assertTrue(transformed_img.equal(transformed_batch[i, ...])) | ||
assert_equal(transformed_img, transformed_batch[i, ...]) | ||
|
||
if scripted_fn_atol >= 0: | ||
scripted_fn = torch.jit.script(fn) | ||
# scriptable function test | ||
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs) | ||
self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol)) | ||
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol) | ||
|
||
def test_assert_image_tensor(self): | ||
shape = (100,) | ||
|
@@ -79,7 +80,7 @@ def test_vflip(self): | |
|
||
# scriptable function test | ||
vflipped_img_script = script_vflip(img_tensor) | ||
self.assertTrue(vflipped_img.equal(vflipped_img_script)) | ||
assert_equal(vflipped_img, vflipped_img_script) | ||
|
||
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) | ||
self._test_fn_on_batch(batch_tensors, F.vflip) | ||
|
@@ -94,7 +95,7 @@ def test_hflip(self): | |
|
||
# scriptable function test | ||
hflipped_img_script = script_hflip(img_tensor) | ||
self.assertTrue(hflipped_img.equal(hflipped_img_script)) | ||
assert_equal(hflipped_img, hflipped_img_script) | ||
|
||
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) | ||
self._test_fn_on_batch(batch_tensors, F.hflip) | ||
|
@@ -140,11 +141,10 @@ def test_hsv2rgb(self): | |
for h1, s1, v1 in zip(h, s, v): | ||
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1)) | ||
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device) | ||
max_diff = (ft_img - colorsys_img).abs().max() | ||
self.assertLess(max_diff, 1e-5) | ||
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5) | ||
|
||
s_rgb_img = scripted_fn(hsv_img) | ||
self.assertTrue(rgb_img.allclose(s_rgb_img)) | ||
torch.testing.assert_close(rgb_img, s_rgb_img) | ||
|
||
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float() | ||
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb) | ||
|
@@ -177,7 +177,7 @@ def test_rgb2hsv(self): | |
self.assertLess(max_diff, 1e-5) | ||
|
||
s_hsv_img = scripted_fn(rgb_img) | ||
self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7)) | ||
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7) | ||
|
||
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float() | ||
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv) | ||
|
@@ -194,7 +194,7 @@ def test_rgb_to_grayscale(self): | |
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max") | ||
|
||
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) | ||
self.assertTrue(s_gray_tensor.equal(gray_tensor)) | ||
assert_equal(s_gray_tensor, gray_tensor) | ||
|
||
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) | ||
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) | ||
|
@@ -240,12 +240,12 @@ def test_five_crop(self): | |
for j in range(len(tuple_transformed_imgs)): | ||
true_transformed_img = tuple_transformed_imgs[j] | ||
transformed_img = tuple_transformed_batches[j][i, ...] | ||
self.assertTrue(true_transformed_img.equal(transformed_img)) | ||
assert_equal(true_transformed_img, transformed_img) | ||
|
||
# scriptable function test | ||
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11]) | ||
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches): | ||
self.assertTrue(transformed_batch.equal(s_transformed_batch)) | ||
assert_equal(transformed_batch, s_transformed_batch) | ||
|
||
def test_ten_crop(self): | ||
script_ten_crop = torch.jit.script(F.ten_crop) | ||
|
@@ -272,12 +272,12 @@ def test_ten_crop(self): | |
for j in range(len(tuple_transformed_imgs)): | ||
true_transformed_img = tuple_transformed_imgs[j] | ||
transformed_img = tuple_transformed_batches[j][i, ...] | ||
self.assertTrue(true_transformed_img.equal(transformed_img)) | ||
assert_equal(true_transformed_img, transformed_img) | ||
|
||
# scriptable function test | ||
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11]) | ||
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches): | ||
self.assertTrue(transformed_batch.equal(s_transformed_batch)) | ||
assert_equal(transformed_batch, s_transformed_batch) | ||
|
||
def test_pad(self): | ||
script_fn = torch.jit.script(F.pad) | ||
|
@@ -320,7 +320,7 @@ def test_pad(self): | |
else: | ||
script_pad = pad | ||
pad_tensor_script = script_fn(tensor, script_pad, **kwargs) | ||
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs)) | ||
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs)) | ||
|
||
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs) | ||
|
||
|
@@ -348,9 +348,10 @@ def test_resize(self): | |
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size) | ||
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size) | ||
|
||
self.assertEqual( | ||
resized_tensor.size()[1:], resized_pil_img.size[::-1], | ||
msg="{}, {}".format(size, interpolation) | ||
assert_equal( | ||
resized_tensor.size()[1:], | ||
resized_pil_img.size[::-1], | ||
msg="{}, {}".format(size, interpolation), | ||
) | ||
|
||
if interpolation not in [NEAREST, ]: | ||
|
@@ -374,7 +375,7 @@ def test_resize(self): | |
|
||
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, | ||
max_size=max_size) | ||
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) | ||
assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation)) | ||
|
||
self._test_fn_on_batch( | ||
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size | ||
|
@@ -384,7 +385,7 @@ def test_resize(self): | |
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): | ||
res1 = F.resize(tensor, size=32, interpolation=2) | ||
res2 = F.resize(tensor, size=32, interpolation=BILINEAR) | ||
self.assertTrue(res1.equal(res2)) | ||
assert_equal(res1, res2) | ||
|
||
for img in (tensor, pil_img): | ||
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge" | ||
|
@@ -400,15 +401,17 @@ def test_resized_crop(self): | |
|
||
for mode in [NEAREST, BILINEAR, BICUBIC]: | ||
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) | ||
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) | ||
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) | ||
|
||
# 2) resize by half and crop a TL corner | ||
tensor, _ = self._create_data(26, 36, device=self.device) | ||
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST) | ||
expected_out_tensor = tensor[:, :20:2, :30:2] | ||
self.assertTrue( | ||
expected_out_tensor.equal(out_tensor), | ||
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]) | ||
assert_equal( | ||
expected_out_tensor, | ||
out_tensor, | ||
check_stride=False, | ||
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]), | ||
) | ||
|
||
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) | ||
|
@@ -420,15 +423,11 @@ def _test_affine_identity_map(self, tensor, scripted_affine): | |
# 1) identity map | ||
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) | ||
|
||
self.assertTrue( | ||
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) | ||
) | ||
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) | ||
out_tensor = scripted_affine( | ||
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST | ||
) | ||
self.assertTrue( | ||
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) | ||
) | ||
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) | ||
|
||
def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine): | ||
# 2) Test rotation | ||
|
@@ -452,9 +451,11 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine): | |
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST | ||
) | ||
if true_tensor is not None: | ||
self.assertTrue( | ||
true_tensor.equal(out_tensor), | ||
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]) | ||
assert_equal( | ||
true_tensor, | ||
out_tensor, | ||
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]), | ||
check_stride=False, | ||
) | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
|
@@ -593,18 +594,19 @@ def test_affine(self): | |
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): | ||
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2) | ||
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) | ||
self.assertTrue(res1.equal(res2)) | ||
assert_equal(res1, res2) | ||
|
||
# assert changed type warning | ||
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): | ||
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) | ||
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) | ||
self.assertTrue(res1.equal(res2)) | ||
assert_equal(res1, res2) | ||
|
||
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): | ||
res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10) | ||
res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10) | ||
self.assertEqual(res1, res2) | ||
# we convert the PIL images to numpy as assert_equal doesn't work on PIL images. | ||
assert_equal(np.asarray(res1), np.asarray(res2)) | ||
|
||
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): | ||
img_size = pil_img.size | ||
|
@@ -682,13 +684,13 @@ def test_rotate(self): | |
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): | ||
res1 = F.rotate(tensor, 45, resample=2) | ||
res2 = F.rotate(tensor, 45, interpolation=BILINEAR) | ||
self.assertTrue(res1.equal(res2)) | ||
assert_equal(res1, res2) | ||
|
||
# assert changed type warning | ||
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): | ||
res1 = F.rotate(tensor, 45, interpolation=2) | ||
res2 = F.rotate(tensor, 45, interpolation=BILINEAR) | ||
self.assertTrue(res1.equal(res2)) | ||
assert_equal(res1, res2) | ||
|
||
def test_gaussian_blur(self): | ||
small_image_tensor = torch.from_numpy( | ||
|
@@ -747,10 +749,8 @@ def test_gaussian_blur(self): | |
|
||
for fn in [F.gaussian_blur, scripted_transform]: | ||
out = fn(tensor, kernel_size=ksize, sigma=sigma) | ||
self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma)) | ||
self.assertLessEqual( | ||
torch.max(true_out.float() - out.float()), | ||
1.0, | ||
torch.testing.assert_close( | ||
out, true_out, rtol=0.0, atol=1.0, check_stride=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we're comparing the max abs difference here instead of the max difference in master, but it's probably more correct with the proposed changes |
||
msg="{}, {}".format(ksize, sigma) | ||
) | ||
|
||
|
@@ -771,7 +771,7 @@ def test_scale_channel(self): | |
img_chan = torch.randint(0, 256, size=size).to('cpu') | ||
scaled_cpu = F_t._scale_channel(img_chan) | ||
scaled_cuda = F_t._scale_channel(img_chan.to('cuda')) | ||
self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu'))) | ||
assert_equal(scaled_cpu, scaled_cuda.to('cpu')) | ||
|
||
|
||
def _get_data_dims_and_points_for_perspective(): | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For ref the
rtol=1e-5
comes from the current default of https://pytorch.org/docs/stable/generated/torch.allclose.htmlI'd prefer to leave
rtol
to its default inassert_close
if possible, butrtol
must be set ifatol
is set. Would you know the reason @pmeier ?np.testing.assert_allclose
doesn't have this constraint it seemsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We went for this logic, since we have non-zero defaults for
rtol
andatol
. Imagine settingatol=0
and the tensors still match becausertol > 0
. See https://github.com/pytorch/pytorch/blob/74c12da4517c789bea737dc947d6adc755f63176/torch/testing/_asserts.py#L391-L396.