Implemented basic Mutable torch tensorrt module pipeline #2981

Merged
merged 31 commits into from
Aug 15, 2024
712ed6c
Implemented basic Mutable torch tensorrt module pipeline
cehongwang Jul 3, 2024
a58c736
Make MutableTorchTensorRTModule a standalone module. seperated the se…
cehongwang Jul 5, 2024
1ff5aed
Added input interception. No manually compile is needed
cehongwang Jul 5, 2024
359c648
Added input check and re-compilation
cehongwang Jul 5, 2024
ae086b3
Revised the refit flag to a enum
cehongwang Jul 5, 2024
cd0207b
Added saving, copy functionality
cehongwang Jul 8, 2024
46100e4
Fixed bug of refit state being a global object
cehongwang Jul 8, 2024
5657153
Delete changes to torchtrt.compile
cehongwang Jul 8, 2024
25f91ac
Fixed code lint
cehongwang Jul 10, 2024
964bc71
Added stable diffusion LoRA example
cehongwang Jul 11, 2024
5a28c10
Added SD lora example. Waiting kwarg compile to be merged to main and…
cehongwang Jul 11, 2024
9b2f354
continued
cehongwang Jul 11, 2024
a1634bc
continued
cehongwang Jul 11, 2024
fa0579b
Added original model offload to cpu
cehongwang Jul 11, 2024
5f0d767
Updated
cehongwang Jul 12, 2024
dd382e7
Updated attribute changes to UNKNOWN. Conditions will be checked on t…
cehongwang Jul 15, 2024
2a970b6
Updated attribute changes to UNKNOWN. Conditions will be checked on t…
cehongwang Jul 15, 2024
912306c
Updated attribute changes to UNKNOWN. Conditions will be checked on t…
cehongwang Jul 15, 2024
9c52267
Added test cases and fixed some bugs
cehongwang Aug 7, 2024
b4d6da6
Added kwarg changes back
cehongwang Aug 7, 2024
4e5b8cc
Added test for mutable torchtrt module
cehongwang Aug 7, 2024
f688817
Enabled Mutable TensorRT Module of SD1.5
cehongwang Aug 8, 2024
2740240
Added subclass compatibility
cehongwang Aug 9, 2024
0e958b6
Rewrote module initialization logic
cehongwang Aug 9, 2024
71b1af4
Added skip to tests if the runtime is not supported
cehongwang Aug 9, 2024
bb31956
Fixed the bug of change trigger wrapper incursion in original pytorch…
cehongwang Aug 13, 2024
c368746
Fixed the comments
cehongwang Aug 13, 2024
157b5ba
Completed docs
cehongwang Aug 13, 2024
bd685ae
Fixed a bug in recursive remove trigger
cehongwang Aug 13, 2024
83d7276
Rebased to main after refit acceleration is merged.
cehongwang Aug 13, 2024
86891e8
refactor: Use HF to download the LoRA (#3089)
narendasan Aug 15, 2024
3 changes: 3 additions & 0 deletions docsrc/py_api/torch_tensorrt.rst
@@ -32,6 +32,9 @@ Functions

Classes
---------
.. autoclass:: MutableTorchTensorRTModule
   :members:
   :special-members: __init__

.. autoclass:: Input
   :members:
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -13,4 +13,5 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
115 changes: 115 additions & 0 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -0,0 +1,115 @@
"""
.. _mutable_torchtrt_module_example:

Mutable Torch TensorRT Module
===================================================================

We are going to demonstrate how to use the Mutable Torch TensorRT Module to compile, interact with, and modify a TensorRT Graph Module.

Compiling a Torch-TensorRT module is straightforward, but modifying the compiled module can be challenging, especially when it comes to maintaining the state and connection between the PyTorch module and the corresponding Torch-TensorRT module.
In Ahead-of-Time (AoT) scenarios, integrating Torch TensorRT with complex pipelines, such as the Hugging Face Stable Diffusion pipeline, becomes even more difficult.
The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.

In this tutorial, we are going to walk through:

1. A sample workflow of the Mutable Torch TensorRT Module with ResNet 18
2. Saving a Mutable Torch TensorRT Module
3. Integration with the Hugging Face pipeline in a LoRA use case
"""

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# %%
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
settings = {
    "use_python_runtime": False,
    "enabled_precisions": {torch.float32},
    "make_refitable": True,
}

model = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module.
mutable_module(*inputs)

# %%
# Make modifications to the mutable module.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# %%
# Changes to the mutable module can trigger a refit or a re-compilation. For example, loading a different state_dict or setting new weight values triggers a refit, while adding a module to the model triggers a re-compilation.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
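
Review note: the comments in this example describe three behaviours: compilation is deferred to the first call, a weight-only change is picked up as a refit on the next call, and a structural change forces a re-compilation. The block below is a minimal pure-Python sketch of that bookkeeping, not the implementation in this PR. `RefitFlag`, `LazyEngineWrapper`, `load_weights`, and the "same keys means same structure" rule are all hypothetical names and simplifications introduced for illustration, and no TensorRT engine is built.

```python
from enum import Enum, auto


class RefitFlag(Enum):
    """Hypothetical states; the names are illustrative, not the library's."""

    LIVE = auto()             # engine matches the wrapped module
    NEEDS_REFIT = auto()      # weights changed, structure unchanged
    NEEDS_RECOMPILE = auto()  # structure or input shape changed


class LazyEngineWrapper:
    """Toy stand-in for a mutable compiled module (no TensorRT involved)."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self.flag = RefitFlag.NEEDS_RECOMPILE  # nothing has been built yet
        self.input_shape = None
        self.log = []  # records which expensive step each call performed

    def load_weights(self, new_weights):
        # Assumption: identical keys mean only values changed (refit is enough);
        # a different key set means the structure changed (rebuild needed).
        same_structure = set(new_weights) == set(self.weights)
        self.weights = dict(new_weights)
        if self.flag is not RefitFlag.NEEDS_RECOMPILE:
            self.flag = (
                RefitFlag.NEEDS_REFIT if same_structure else RefitFlag.NEEDS_RECOMPILE
            )

    def __call__(self, values):
        shape = len(values)
        # An input-shape change invalidates the previously built engine.
        if self.input_shape is not None and shape != self.input_shape:
            self.flag = RefitFlag.NEEDS_RECOMPILE
        if self.flag is RefitFlag.NEEDS_RECOMPILE:
            self.log.append("compile")
        elif self.flag is RefitFlag.NEEDS_REFIT:
            self.log.append("refit")
        self.input_shape = shape
        self.flag = RefitFlag.LIVE
        total = sum(self.weights.values())
        return [v * total for v in values]


wrapper = LazyEngineWrapper({"w": 2.0})
wrapper([1.0, 2.0])                          # first call: compile
wrapper.load_weights({"w": 3.0})             # same keys: marks a refit
out = wrapper([1.0, 2.0])                    # the refit happens here
wrapper.load_weights({"w": 3.0, "b": 1.0})   # new key: marks a re-compile
wrapper([1.0, 2.0])
wrapper([1.0, 2.0, 3.0])                     # new input shape: re-compile
print(wrapper.log)
```

The point of the design is that mutations only set a flag; the expensive work is done lazily at the next call, which matches the "refit happens when you call the mutable module again" behaviour described in the example.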


# Check the output
# The refit happens the next time you call the mutable module.
expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(
        expected_output, refitted_output, 1e-2, 1e-2
    ), "Refit Result is not correct. Refit failed"

print("Refit successfully!")
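
Review note: the two positional `1e-2` arguments passed to `torch.allclose` above are `rtol` and `atol`, in that order. The element-wise criterion is `|input - other| <= atol + rtol * |other|`. The sketch below re-implements that criterion in plain Python to show how loose the check is compared with the defaults. The helper name `allclose` and the sample values are made up for illustration, and the sketch ignores NaN handling and broadcasting.

```python
def allclose(inputs, others, rtol=1e-5, atol=1e-8):
    """Plain-Python version of the tolerance test used by torch.allclose."""
    return all(
        abs(a - b) <= atol + rtol * abs(b) for a, b in zip(inputs, others)
    )


# Made-up values standing in for PyTorch outputs and refitted outputs.
expected = [1.000, 100.0, 0.0]
refitted = [1.005, 100.9, 0.009]

loose = allclose(expected, refitted, 1e-2, 1e-2)  # tolerances used in the example
strict = allclose(expected, refitted)             # default tolerances
print(loose, strict)
```

With `rtol = atol = 1e-2` all three pairs pass, while the default tolerances reject them, so the assertion in the example checks approximate agreement rather than exact numerical equality.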

# %%
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Currently, saving is only enabled for the C++ runtime, not the Python runtime.
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

# %%
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# The LoRA checkpoint is from https://civitai.com/models/12597/moxin

from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refitable": True,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")
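
Review note: a likely reason the LoRA step above triggers a refit rather than a re-compilation is that fusing a LoRA adapter folds a low-rank update into the existing weight, `W' = W + scale * (B @ A)`, so the weight values change while every shape stays the same. The block below is a small pure-Python sketch of that general idea under stated assumptions. It is not the diffusers implementation, and `matmul`, `fuse_lora`, `scale`, and the toy matrices are hypothetical names and values chosen for illustration.

```python
def matmul(a, b):
    """Multiply two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def fuse_lora(weight, lora_b, lora_a, scale=1.0):
    """Return weight + scale * (lora_b @ lora_a); the shape is unchanged."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # base weight, 2 x 3
lora_b = [[1.0], [2.0]]                       # 2 x 1 (rank 1)
lora_a = [[0.5, 0.0, -0.5]]                   # 1 x 3

fused = fuse_lora(weight, lora_b, lora_a, scale=2.0)
print(fused)
```

Because `fused` has the same dimensions as `weight`, a compiled graph built for the base model still fits; only the stored weight values need to be updated.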
3 changes: 3 additions & 0 deletions py/torch_tensorrt/__init__.py
@@ -134,3 +134,6 @@ def _register_with_torch() -> None:
from torch_tensorrt import dynamo # noqa: F401

from torch_tensorrt._compile import * # noqa: F403
from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
    MutableTorchTensorRTModule,
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_compile.py
@@ -534,4 +534,4 @@ def save(
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
torch.export.save(exp_program, file_path)