@@ -25,7 +25,7 @@ class RefitFlag(Enum):
25
25
26
26
27
27
class RefitState :
28
- _state : RefitFlag = RefitFlag .UNKNOWN
28
+ _state : RefitFlag = RefitFlag .NEEDS_RECOMPILE
29
29
30
30
def set_state (self , state : RefitFlag ) -> None :
31
31
if isinstance (state , RefitFlag ):
@@ -267,12 +267,14 @@ def refit_gm(self) -> None:
267
267
self .original_model .state_dict ()
268
268
)
269
269
)
270
- self .gm = refit_module_weights (self .gm , self .exp_program )
270
+ self .gm = refit_module_weights (
271
+ self .gm , self .exp_program , use_weight_map_cache = True , in_place = True
272
+ )
271
273
272
274
self .original_model .cpu ()
273
275
torch .cuda .empty_cache ()
274
276
275
- def _compile (self ) -> None :
277
+ def compile (self ) -> None :
276
278
"""
277
279
(Re)compile the TRT graph module using the PyTorch module.
278
280
This function should be called whenever the weight structure get changed (shape, more layers...)
@@ -349,7 +351,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
349
351
# Step 3: Refit/recompile accordingly
350
352
if self .refit_state .get_state () == RefitFlag .NEEDS_RECOMPILE :
351
353
logger .info ("(Re)Compiling the engine..." )
352
- self ._compile ()
354
+ self .compile ()
353
355
self .store_state_dict_metadata ()
354
356
self .refit_state .set_state (RefitFlag .LIVE )
355
357
@@ -360,7 +362,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
360
362
except Exception as e :
361
363
logger .error (e )
362
364
logger .error ("Model refit failed. Recompiling the graph module." )
363
- self ._compile ()
365
+ self .compile ()
364
366
self .store_state_dict_metadata ()
365
367
self .refit_state .set_state (RefitFlag .LIVE )
366
368
@@ -369,6 +371,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
369
371
self .run_info = (args , kwargs , result )
370
372
return result
371
373
374
+ def to (self , device : str ):
375
+ logger .warning ("Original PyTorch model is moved. CPU offload may failed." )
376
+ self .orignial_model .to (device )
377
+
372
378
def __deepcopy__ (self , memo : Any ) -> Any :
373
379
cls = self .__class__
374
380
result = cls .__new__ (cls )
0 commit comments