Commit f8b6ac6

committed
Handled the memory issue
1 parent 30f3b15 commit f8b6ac6

1 file changed: +8 -6 lines changed


py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 8 additions & 6 deletions
@@ -3,6 +3,7 @@
 import logging
 from typing import Any, List, Optional, Sequence
 
+import tensorrt as trt
 import torch
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
@@ -17,8 +18,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -131,13 +130,13 @@ def convert_module(
     from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
     from torch_tensorrt.logging import TRT_LOGGER
 
-    runtime = trt.Runtime(TRT_LOGGER)
-    refit_test_engine = runtime.deserialize_cuda_engine(
-        interpreter_result.serialized_engine
-    )
     weight_name_map: Any = None
     # Do the test refit with cached map if make_refitable is enabled
     if settings.make_refitable:
+        runtime = trt.Runtime(TRT_LOGGER)
+        refit_test_engine = runtime.deserialize_cuda_engine(
+            interpreter_result.serialized_engine
+        )
         weight_name_map = interpreter_result.weight_name_map
         try:
             _refit_single_trt_engine_with_gm(
@@ -150,6 +149,9 @@ def convert_module(
         except AssertionError:
             logger.warning("Fast refit test failed. Removing the weight map caching.")
 
+        del refit_test_engine
+        torch.cuda.empty_cache()
+
     rt_cls = PythonTorchTensorRTModule
 
     if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
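
The change is a scoping and cleanup pattern rather than anything specific to this file: the throwaway engine used for the refit test is now deserialized only inside the make_refitable branch, and the reference is dropped (followed by torch.cuda.empty_cache()) once the test is done, so the extra deserialized copy of the engine does not stay resident for the rest of conversion. A minimal sketch of that pattern, with a hypothetical helper and parameters (run_refit_check, serialized_engine, do_check) standing in for the real convert_module plumbing:

import tensorrt as trt
import torch


def run_refit_check(serialized_engine: bytes, do_check: bool) -> None:
    # Hypothetical helper; mirrors the commit's memory pattern, not the project's API.
    if do_check:
        # Deserialize the test engine only when the check is actually requested.
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        test_engine = runtime.deserialize_cuda_engine(serialized_engine)
        try:
            pass  # run the refit test against test_engine here
        finally:
            # Drop the temporary engine and release cached CUDA memory so the
            # extra copy does not linger after the check.
            del test_engine
            torch.cuda.empty_cache()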
