Skip to content

Commit bd685ae

Browse files
committed
Fixed a bug in recursive remove trigger
1 parent 157b5ba commit bd685ae

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,7 @@ def update_refit_condition(self) -> None:
227227
self.original_model.to(to_torch_device(self.settings.device))
228228
new_result = self.original_model(*args, **kwargs)
229229
self.original_model.cpu()
230+
torch.cuda.empty_cache()
230231
if MutableTorchTensorRTModule.check_output_equal(result, new_result):
231232
self.refit_state.set_state(RefitFlag.LIVE)
232233
return
@@ -269,6 +270,7 @@ def refit_gm(self) -> None:
269270
self.gm = refit_module_weights(self.gm, self.exp_program)
270271

271272
self.original_model.cpu()
273+
torch.cuda.empty_cache()
272274

273275
def _compile(self) -> None:
274276
"""
@@ -291,6 +293,7 @@ def _compile(self) -> None:
291293
**self.settings.__dict__,
292294
)
293295
self.original_model.cpu()
296+
torch.cuda.empty_cache()
294297

295298
def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
296299
if (
@@ -354,7 +357,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
354357
logger.info("Model weight change detected. Refitting the module...")
355358
try:
356359
self.refit_gm()
357-
except Exception:
360+
except Exception as e:
361+
logger.error(e)
358362
logger.error("Model refit failed. Recompiling the graph module.")
359363
self._compile()
360364
self.store_state_dict_metadata()
@@ -550,10 +554,12 @@ def recursively_remove_trigger(obj: Any) -> Any:
550554
for i, v in enumerate(obj):
551555
obj[i] = recursively_remove_trigger(v)
552556
else:
553-
if not hasattr(obj, "__dict__"):
557+
if not hasattr(obj, "__dict__") or isinstance(obj, (type,)):
554558
return obj
555559
for k, v in obj.__dict__.items():
556-
setattr(obj, k, recursively_remove_trigger(v))
560+
if k[:2] != "__" or k[-2:] != "__":
561+
# We don't want to touch built-in dunder attributes such as __dict__
562+
setattr(obj, k, recursively_remove_trigger(v))
557563

558564
return obj
559565

0 commit comments

Comments
 (0)