14
14
import torch .testing
15
15
from datasets_utils import combinations_grid
16
16
from torch .nn .functional import one_hot
17
- from torch .testing ._comparison import assert_equal as _assert_equal , BooleanPair , NonePair , NumberPair , TensorLikePair
17
+ from torch .testing ._comparison import BooleanPair , NonePair , not_close_error_metas , NumberPair , TensorLikePair
18
18
from torchvision .prototype import datapoints
19
19
from torchvision .prototype .transforms .functional import convert_dtype_image_tensor , to_image_tensor
20
20
from torchvision .transforms .functional_tensor import _max_value as get_max_value
@@ -73,7 +73,7 @@ def compare(self) -> None:
73
73
actual , expected = self ._promote_for_comparison (actual , expected )
74
74
mae = float (torch .abs (actual - expected ).float ().mean ())
75
75
if mae > self .atol :
76
- raise self ._make_error_meta (
76
+ self ._fail (
77
77
AssertionError ,
78
78
f"The MAE of the images is { mae } , but only { self .atol } is allowed." ,
79
79
)
@@ -99,7 +99,7 @@ def assert_close(
99
99
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
100
100
__tracebackhide__ = True
101
101
102
- _assert_equal (
102
+ error_metas = not_close_error_metas (
103
103
actual ,
104
104
expected ,
105
105
pair_types = (
@@ -117,10 +117,12 @@ def assert_close(
117
117
check_dtype = check_dtype ,
118
118
check_layout = check_layout ,
119
119
check_stride = check_stride ,
120
- msg = msg ,
121
120
** kwargs ,
122
121
)
123
122
123
+ if error_metas :
124
+ raise error_metas [0 ].to_error (msg )
125
+
124
126
125
127
assert_equal = functools .partial (assert_close , rtol = 0 , atol = 0 )
126
128
0 commit comments