Skip to content

Commit 83d7276

Browse files
committed
Rebased to main after refit acceleration is merged.
1 parent bd685ae commit 83d7276

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import torch_tensorrt as torch_trt
2222
import torchvision.models as models
2323

24-
np.random.seed(0)
25-
torch.manual_seed(0)
24+
np.random.seed(5)
25+
torch.manual_seed(5)
2626
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
2727

2828
# %%
@@ -76,7 +76,7 @@
7676
from diffusers import DiffusionPipeline
7777

7878
with torch.no_grad():
79-
kwargs = {
79+
settings = {
8080
"use_python_runtime": True,
8181
"enabled_precisions": {torch.float16},
8282
"debug": True,
@@ -86,7 +86,7 @@
8686
model_id = "runwayml/stable-diffusion-v1-5"
8787
device = "cuda:0"
8888

89-
prompt = "portrait of a woman standing, shuimobysim, wuchangshuo, best quality"
89+
prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
9090
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"
9191

9292
pipe = DiffusionPipeline.from_pretrained(
@@ -95,7 +95,7 @@
9595
pipe.to(device)
9696

9797
# The only extra line you need
98-
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
98+
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
9999

100100
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
101101
image.save("./without_LoRA_mutable.jpg")

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class RefitFlag(Enum):
2525

2626

2727
class RefitState:
28-
_state: RefitFlag = RefitFlag.UNKNOWN
28+
_state: RefitFlag = RefitFlag.NEEDS_RECOMPILE
2929

3030
def set_state(self, state: RefitFlag) -> None:
3131
if isinstance(state, RefitFlag):
@@ -267,12 +267,14 @@ def refit_gm(self) -> None:
267267
self.original_model.state_dict()
268268
)
269269
)
270-
self.gm = refit_module_weights(self.gm, self.exp_program)
270+
self.gm = refit_module_weights(
271+
self.gm, self.exp_program, use_weight_map_cache=True, in_place=True
272+
)
271273

272274
self.original_model.cpu()
273275
torch.cuda.empty_cache()
274276

275-
def _compile(self) -> None:
277+
def compile(self) -> None:
276278
"""
277279
(Re)compile the TRT graph module using the PyTorch module.
278280
This function should be called whenever the weight structure changes (shape, more layers...)
@@ -349,7 +351,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
349351
# Step 3: Refit/recompile accordingly
350352
if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
351353
logger.info("(Re)Compiling the engine...")
352-
self._compile()
354+
self.compile()
353355
self.store_state_dict_metadata()
354356
self.refit_state.set_state(RefitFlag.LIVE)
355357

@@ -360,7 +362,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
360362
except Exception as e:
361363
logger.error(e)
362364
logger.error("Model refit failed. Recompiling the graph module.")
363-
self._compile()
365+
self.compile()
364366
self.store_state_dict_metadata()
365367
self.refit_state.set_state(RefitFlag.LIVE)
366368

@@ -369,6 +371,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
369371
self.run_info = (args, kwargs, result)
370372
return result
371373

374+
def to(self, device: str) -> None:
    """
    Move the wrapped original PyTorch model to ``device``.

    A warning is logged before the move.

    Args:
        device: Target device, forwarded unchanged to ``self.original_model.to``.
    """
    logger.warning("Original PyTorch model is moved. CPU offload may failed.")
    # The attribute is spelled ``original_model`` everywhere else in this class;
    # the previous ``orignial_model`` spelling raised AttributeError on every call.
    self.original_model.to(device)
377+
372378
def __deepcopy__(self, memo: Any) -> Any:
373379
cls = self.__class__
374380
result = cls.__new__(cls)

0 commit comments

Comments
 (0)