3
3
import logging
4
4
from typing import Any , List , Optional , Sequence
5
5
6
+ import tensorrt as trt
6
7
import torch
7
8
from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
8
9
from torch_tensorrt ._Device import Device
17
18
from torch_tensorrt .dynamo .runtime import PythonTorchTensorRTModule , TorchTensorRTModule
18
19
from torch_tensorrt .dynamo .utils import get_torch_inputs
19
20
20
- import tensorrt as trt
21
-
22
21
logger = logging .getLogger (__name__ )
23
22
24
23
@@ -131,13 +130,13 @@ def convert_module(
131
130
from torch_tensorrt .dynamo ._refit import _refit_single_trt_engine_with_gm
132
131
from torch_tensorrt .logging import TRT_LOGGER
133
132
134
- runtime = trt .Runtime (TRT_LOGGER )
135
- refit_test_engine = runtime .deserialize_cuda_engine (
136
- interpreter_result .serialized_engine
137
- )
138
133
weight_name_map : Any = None
139
134
# Do the test refit with cached map if make_refitable is enabled
140
135
if settings .make_refitable :
136
+ runtime = trt .Runtime (TRT_LOGGER )
137
+ refit_test_engine = runtime .deserialize_cuda_engine (
138
+ interpreter_result .serialized_engine
139
+ )
141
140
weight_name_map = interpreter_result .weight_name_map
142
141
try :
143
142
_refit_single_trt_engine_with_gm (
@@ -150,6 +149,9 @@ def convert_module(
150
149
except AssertionError :
151
150
logger .warning ("Fast refit test failed. Removing the weight map caching." )
152
151
152
+ del refit_test_engine
153
+ torch .cuda .empty_cache ()
154
+
153
155
rt_cls = PythonTorchTensorRTModule
154
156
155
157
if ENABLED_FEATURES .torch_tensorrt_runtime and not settings .use_python_runtime :
0 commit comments