@@ -227,6 +227,7 @@ def update_refit_condition(self) -> None:
227
227
self .original_model .to (to_torch_device (self .settings .device ))
228
228
new_result = self .original_model (* args , ** kwargs )
229
229
self .original_model .cpu ()
230
+ torch .cuda .empty_cache ()
230
231
if MutableTorchTensorRTModule .check_output_equal (result , new_result ):
231
232
self .refit_state .set_state (RefitFlag .LIVE )
232
233
return
@@ -269,6 +270,7 @@ def refit_gm(self) -> None:
269
270
self .gm = refit_module_weights (self .gm , self .exp_program )
270
271
271
272
self .original_model .cpu ()
273
+ torch .cuda .empty_cache ()
272
274
273
275
def _compile (self ) -> None :
274
276
"""
@@ -291,6 +293,7 @@ def _compile(self) -> None:
291
293
** self .settings .__dict__ ,
292
294
)
293
295
self .original_model .cpu ()
296
+ torch .cuda .empty_cache ()
294
297
295
298
def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
296
299
if (
@@ -354,7 +357,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
354
357
logger .info ("Model weight change detected. Refitting the module..." )
355
358
try :
356
359
self .refit_gm ()
357
- except Exception :
360
+ except Exception as e :
361
+ logger .error (e )
358
362
logger .error ("Model refit failed. Recompiling the graph module." )
359
363
self ._compile ()
360
364
self .store_state_dict_metadata ()
@@ -550,10 +554,12 @@ def recursively_remove_trigger(obj: Any) -> Any:
550
554
for i , v in enumerate (obj ):
551
555
obj [i ] = recursively_remove_trigger (v )
552
556
else :
553
- if not hasattr (obj , "__dict__" ):
557
+ if not hasattr (obj , "__dict__" ) or isinstance ( obj , ( type ,)) :
554
558
return obj
555
559
for k , v in obj .__dict__ .items ():
556
- setattr (obj , k , recursively_remove_trigger (v ))
560
+ if k [:2 ] != "__" or k [- 2 :] != "__" :
561
+ # We don't want to touch some built in attribute such as __dict__
562
+ setattr (obj , k , recursively_remove_trigger (v ))
557
563
558
564
return obj
559
565
0 commit comments