Skip to content

Commit 86f551d

Browse files
authored
update usages of torch.testing internals (#7203)
1 parent 5ea8e01 commit 86f551d

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

test/prototype_common_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.testing
1515
from datasets_utils import combinations_grid
1616
from torch.nn.functional import one_hot
17-
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
17+
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
1818
from torchvision.prototype import datapoints
1919
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
2020
from torchvision.transforms.functional_tensor import _max_value as get_max_value
@@ -73,7 +73,7 @@ def compare(self) -> None:
7373
actual, expected = self._promote_for_comparison(actual, expected)
7474
mae = float(torch.abs(actual - expected).float().mean())
7575
if mae > self.atol:
76-
raise self._make_error_meta(
76+
self._fail(
7777
AssertionError,
7878
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
7979
)
@@ -99,7 +99,7 @@ def assert_close(
9999
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
100100
__tracebackhide__ = True
101101

102-
_assert_equal(
102+
error_metas = not_close_error_metas(
103103
actual,
104104
expected,
105105
pair_types=(
@@ -117,10 +117,12 @@ def assert_close(
117117
check_dtype=check_dtype,
118118
check_layout=check_layout,
119119
check_stride=check_stride,
120-
msg=msg,
121120
**kwargs,
122121
)
123122

123+
if error_metas:
124+
raise error_metas[0].to_error(msg)
125+
124126

125127
assert_equal = functools.partial(assert_close, rtol=0, atol=0)
126128

test/test_prototype_datasets_builtin.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import functools
21
import io
32
import pickle
43
from collections import deque
@@ -9,7 +8,7 @@
98

109
import torchvision.prototype.transforms.utils
1110
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
12-
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
11+
from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
1312

1413
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
1514
from torch.utils.data import DataLoader
@@ -25,9 +24,12 @@
2524
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
2625

2726

28-
assert_samples_equal = functools.partial(
29-
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
30-
)
27+
def assert_samples_equal(*args, msg=None, **kwargs):
28+
error_metas = not_close_error_metas(
29+
*args, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True, **kwargs
30+
)
31+
if error_metas:
32+
raise error_metas[0].to_error(msg)
3133

3234

3335
def extract_datapipes(dp):

0 commit comments

Comments
 (0)