Skip to content

Commit 86891e8

Browse files
narendasancehongwang
authored andcommitted
refactor: Use HF to download the LoRA (#3089)
Signed-off-by: Naren Dasan <[email protected]>
1 parent 83d7276 commit 86891e8

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
device = "cuda:0"
8888

8989
prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
90-
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"
90+
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
9191

9292
pipe = DiffusionPipeline.from_pretrained(
9393
model_id, revision="fp16", torch_dtype=torch.float16
@@ -101,7 +101,11 @@
101101
image.save("./without_LoRA_mutable.jpg")
102102

103103
# Standard Huggingface LoRA loading procedure
104-
pipe.load_lora_weights("./moxin.safetensors", adapter_name="lora1")
104+
pipe.load_lora_weights(
105+
"stablediffusionapi/load_lora_embeddings",
106+
weight_name="moxin.safetensors",
107+
adapter_name="lora1",
108+
)
105109
pipe.set_adapters(["lora1"], adapter_weights=[1])
106110
pipe.fuse_lora()
107111
pipe.unload_lora_weights()

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
371371
self.run_info = (args, kwargs, result)
372372
return result
373373

374-
def to(self, device: str):
374+
def to(self, device: str) -> None:
375375
logger.warning("Original PyTorch model is moved. CPU offload may failed.")
376376
self.orignial_model.to(device)
377-
377+
378378
def __deepcopy__(self, memo: Any) -> Any:
379379
cls = self.__class__
380380
result = cls.__new__(cls)

0 commit comments

Comments
 (0)