Skip to content

Commit 762d965

Browse files
committed
chore: move threshold values to variables and rebase with main
1 parent 1ffa050 commit 762d965

File tree

5 files changed

+21
-14
lines changed

5 files changed

+21
-14
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
COSINE_THRESHOLD = 0.99
2424
DYNAMIC_DIM = -1
25+
RTOL = 5e-3
26+
ATOL = 5e-3
2527

2628

2729
class Frameworks(Enum):
@@ -412,6 +414,8 @@ def check_module_output(
412414
def check_output_equal(
413415
output1: Any,
414416
output2: Any,
417+
rtol: float = RTOL,
418+
atol: float = ATOL,
415419
) -> bool:
416420

417421
if type(output1) != type(output2):
@@ -423,7 +427,7 @@ def check_output_equal(
423427
if isinstance(output1, torch.Tensor):
424428
if output1.shape != output2.shape:
425429
return False
426-
return torch.allclose(output1, output2, 5e-3, 5e-3) # type: ignore
430+
return torch.allclose(output1, output2, rtol, atol) # type: ignore
427431

428432
elif isinstance(output1, (tuple, list)):
429433
if len(output1) != len(output2):

tests/py/dynamo/conversion/harness.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
pre_export_lowering,
2424
)
2525
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
26-
from torch_tensorrt.dynamo.utils import get_torch_inputs
26+
from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_torch_inputs
2727

2828
_LOGGER: logging.Logger = logging.getLogger(__name__)
2929

@@ -60,8 +60,8 @@ def run_test(
6060
mod,
6161
inputs,
6262
interpreter,
63-
rtol,
64-
atol,
63+
rtol=RTOL,
64+
atol=ATOL,
6565
check_dtype=True,
6666
pyt_inputs=None,
6767
rt_cls=PythonTorchTensorRTModule,
@@ -254,8 +254,8 @@ def run_test(
254254
self,
255255
mod,
256256
inputs,
257-
rtol=5e-3,
258-
atol=5e-3,
257+
rtol=RTOL,
258+
atol=ATOL,
259259
precision=dtype.f32,
260260
check_dtype=True,
261261
use_dynamo_tracer=False,
@@ -374,8 +374,8 @@ def run_test_with_dynamic_shape(
374374
self,
375375
mod,
376376
input_specs,
377-
rtol=5e-3,
378-
atol=5e-3,
377+
rtol=RTOL,
378+
atol=ATOL,
379379
output_dtypes=None,
380380
use_dynamo_tracer=False,
381381
enable_passes=False,

tests/py/dynamo/conversion/test_bitwise_and_aten.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch.export import Dim
66
from torch.testing._internal.common_utils import run_tests
77
from torch_tensorrt import Input
8+
from torch_tensorrt.dynamo.utils import ATOL, RTOL
89

910
from .harness import DispatchTestCase
1011

@@ -152,8 +153,8 @@ def forward(self, lhs_val, rhs_val):
152153
torch.testing.assert_close(
153154
out,
154155
ref,
155-
rtol=5e-3,
156-
atol=5e-3,
156+
rtol=RTOL,
157+
atol=ATOL,
157158
equal_nan=True,
158159
check_dtype=True,
159160
)

tests/py/dynamo/conversion/test_embedding_bag_aten.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from parameterized import param, parameterized
44
from torch.testing._internal.common_utils import run_tests
55
from torch_tensorrt import Input
6+
from torch_tensorrt.dynamo.utils import ATOL, RTOL
67

78
from .harness import DispatchTestCase
89

@@ -501,8 +502,8 @@ def forward(self, weights, indices, offsets, per_sample_weights=None):
501502
torch.testing.assert_close(
502503
out,
503504
ref,
504-
rtol=5e-3,
505-
atol=5e-3,
505+
rtol=RTOL,
506+
atol=ATOL,
506507
equal_nan=True,
507508
check_dtype=True,
508509
)

tests/py/dynamo/conversion/test_index_select_aten.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from parameterized import param, parameterized
55
from torch.testing._internal.common_utils import run_tests
66
from torch_tensorrt import Input
7+
from torch_tensorrt.dynamo.utils import ATOL, RTOL
78

89
from .harness import DispatchTestCase
910

@@ -122,8 +123,8 @@ def forward(self, source_tensor, indice_tensor):
122123
torch.testing.assert_close(
123124
out,
124125
ref,
125-
rtol=5e-3,
126-
atol=5e-3,
126+
rtol=RTOL,
127+
atol=ATOL,
127128
equal_nan=True,
128129
check_dtype=True,
129130
)

0 commit comments

Comments
 (0)