Skip to content

Use torch.testing.assert_close in test_transforms_tensor.py #3885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 21, 2021
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c4fc01b
adopt `torch.testing.assert_close` in test suite
pmeier May 20, 2021
bfbe19b
revert some changes
pmeier May 20, 2021
09f86f4
add todo
pmeier May 20, 2021
86402f0
flake8
pmeier May 20, 2021
48d32e6
Hopefully fixed test_functional_tensor
NicolasHug May 20, 2021
15b50e3
hopefully fixed test_ops
NicolasHug May 20, 2021
a54880f
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 20, 2021
61874ac
Fix test_utils
NicolasHug May 20, 2021
30f20a3
revert unwanted changes to test_image
NicolasHug May 20, 2021
3a29ae3
maybe fixed test_transforms
NicolasHug May 20, 2021
d6d73d0
Merge branch 'master' into assert-close
NicolasHug May 20, 2021
863f144
fix test_datasets_video_utils
pmeier May 21, 2021
c8a5afa
fix test_transforms
pmeier May 21, 2021
e697e88
Merge branch 'master' into assert-close
pmeier May 21, 2021
93614f0
flake8
pmeier May 21, 2021
11caf01
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 21, 2021
d7fde8c
Merge branch 'assert-close' of github.com:pmeier/vision into assert-c…
NicolasHug May 21, 2021
0b237c7
use cu102 see if the nightlies are actual nightlies?
NicolasHug May 21, 2021
c2ace86
obviously forgot to call regenerate.py
NicolasHug May 21, 2021
d78226a
not as obvious, reverting
NicolasHug May 21, 2021
bb543a7
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
7507a0c
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
33de161
revert everything but transforms_tensor
NicolasHug May 21, 2021
7d4dbe5
Merge branch 'master' into assert_close_transforms_tensor
NicolasHug May 21, 2021
2880601
Merge branch 'master' into assert_close_transforms_tensor
NicolasHug May 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Sequence

from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from _assert_utils import assert_equal


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
Expand Down Expand Up @@ -38,7 +39,7 @@ def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2), msg=msg)
assert_equal(out1, out2, msg=msg)

def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
Expand All @@ -48,11 +49,11 @@ def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_ten
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
assert_equal(transformed_batch, s_transformed_batch, msg=msg)

def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
Expand All @@ -75,7 +76,7 @@ def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **matc

torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
assert_equal(transformed_tensor, transformed_tensor_script)

batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
Expand Down Expand Up @@ -270,8 +271,11 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
assert_equal(
transformed_tensor,
transformed_tensor_script,
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
)

# test for class interface
fn = getattr(T, method)(**meth_kwargs)
Expand All @@ -289,8 +293,11 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
torch.manual_seed(12)
transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
assert_equal(
transformed_img,
transformed_batch[i, ...],
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
)

with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
Expand Down Expand Up @@ -505,7 +512,7 @@ def test_linear_transformation(self):
transformed_batch = fn(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)

with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
Expand All @@ -525,7 +532,7 @@ def test_compose(self):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

t = T.Compose([
lambda x: x,
Expand All @@ -551,7 +558,7 @@ def test_random_apply(self):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise
Expand Down