Commit 157b5ba

Completed docs
1 parent c368746 commit 157b5ba

5 files changed: +87 additions, -98 deletions

docsrc/py_api/torch_tensorrt.rst

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ Functions
 
 Classes
 ---------
+.. autoclass:: MutableTorchTensorRTModule
+   :members:
+   :special-members: __init__
 
 .. autoclass:: Input
    :members:
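With this directive in place, other pages in the documentation should be able to link to the class through the standard Sphinx cross-reference role. A minimal sketch, assuming the class is exposed at the top-level torch_tensorrt namespace as the diff above implies:

    See :class:`torch_tensorrt.MutableTorchTensorRTModule` for compiling,
    using, and modifying a TensorRT Graph Module in place.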

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
@@ -13,4 +13,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
+* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 70 additions & 13 deletions
@@ -1,19 +1,19 @@
 """
-.. _refit_engine_example:
+.. _mutable_torchtrt_module_example:
 
-Refit TenorRT Graph Module with Torch-TensorRT
+Mutable Torch TensorRT Module
 ===================================================================
 
-We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
+We are going to demonstrate how to easily use the Mutable Torch TensorRT Module to compile, interact with, and modify a TensorRT Graph Module.
 
-In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
-That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
-Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
+Compiling a Torch-TensorRT module is straightforward, but modifying the compiled module can be challenging, especially when it comes to maintaining the state and connection between the PyTorch module and the corresponding Torch-TensorRT module.
+In Ahead-of-Time (AoT) scenarios, integrating Torch-TensorRT with complex pipelines, such as the Hugging Face Stable Diffusion pipeline, becomes even more difficult.
+The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
 
 In this tutorial, we are going to walk through
-1. Compiling a PyTorch model to a TensorRT Graph Module
-2. Save and load a graph module
-3. Refit the graph module
+1. A sample workflow of the Mutable Torch TensorRT Module with ResNet 18
+2. Saving a Mutable Torch TensorRT Module
+3. Integration with a Huggingface pipeline in a LoRA use case
 """
 
 import numpy as np
@@ -26,29 +26,86 @@
 inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
 
 # %%
-# Compile the module for the first time and save it.
+# Initialize the Mutable Torch TensorRT Module with settings.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-kwargs = {
+settings = {
     "use_python": False,
     "enabled_precisions": {torch.float32},
     "make_refitable": True,
 }
 
 model = models.resnet18(pretrained=False).eval().to("cuda")
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
-mutable_module = torch_trt.MutableTorchTensorRTModule(model, **kwargs)
+mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
+# You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module.
 mutable_module(*inputs)
 
+# %%
+# Make modifications to the mutable module.
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# %%
+# Making changes to the mutable module can trigger a refit or re-compilation. For example, loading a different state_dict or setting new weight values will trigger a refit, while adding a submodule to the model will trigger re-compilation.
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
 mutable_module.load_state_dict(model2.state_dict())
 
 
 # Check the output
+# The refit happens when you call the mutable module again.
 expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
 for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
     assert torch.allclose(
         expected_output, refitted_output, 1e-2, 1e-2
     ), "Refit Result is not correct. Refit failed"
 
 print("Refit successfully!")
+
+# %%
+# Saving Mutable Torch TensorRT Module
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Currently, saving is only enabled for the C++ runtime, not the Python runtime.
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
+
+# %%
+# Stable Diffusion with Huggingface
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
+
+from diffusers import DiffusionPipeline
+
+with torch.no_grad():
+    kwargs = {
+        "use_python_runtime": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+        "make_refitable": True,
+    }
+
+    model_id = "runwayml/stable-diffusion-v1-5"
+    device = "cuda:0"
+
+    prompt = "portrait of a woman standing, shuimobysim, wuchangshuo, best quality"
+    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"
+
+    pipe = DiffusionPipeline.from_pretrained(
+        model_id, revision="fp16", torch_dtype=torch.float16
+    )
+    pipe.to(device)
+
+    # The only extra line you need
+    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
+
+    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    image.save("./without_LoRA_mutable.jpg")
+
+    # Standard Huggingface LoRA loading procedure
+    pipe.load_lora_weights("./moxin.safetensors", adapter_name="lora1")
+    pipe.set_adapters(["lora1"], adapter_weights=[1])
+    pipe.fuse_lora()
+    pipe.unload_lora_weights()
+
+    # Refit triggered
+    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    image.save("./with_LoRA_mutable.jpg")
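The example above exercises only the refit path. The added comment also states that structural changes trigger full re-compilation; the sketch below illustrates that case, continuing from the ResNet section of this example. The replacement of the fc submodule is hypothetical and assumes attribute assignments on the mutable module are intercepted the same way state_dict loads are:

    import torch.nn as nn

    # Swapping a submodule changes the graph structure, so the next forward
    # call should trigger re-compilation rather than a weight-only refit.
    mutable_module.fc = nn.Linear(512, 10).eval().to("cuda")
    outputs = mutable_module(*inputs)  # re-compilation happens here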

examples/dynamo/mutable_torchtrt_module_stable_diffusion_example.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 13 additions & 11 deletions
@@ -38,6 +38,17 @@ def get_state(self) -> RefitFlag:
 
 
 class MutableTorchTensorRTModule(object):
+    """
+    Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
+    All TensorRT compilation and refitting processes are handled automatically as you work with the module.
+    Any changes to its attributes or loading a different state_dict will trigger refitting or recompilation,
+    which will be managed during the next forward pass.
+
+    The MutableTorchTensorRTModule takes a PyTorch module and a set of configuration settings for the compiler.
+    Once compilation is complete, the module maintains the connection between the TensorRT graph module and the original PyTorch module.
+    Any modifications made to the MutableTorchTensorRTModule will be reflected in both the TensorRT graph module and the original PyTorch module.
+    """
+
     def __init__(
         self,
         pytorch_model: torch.nn.Module,
@@ -75,15 +86,6 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         """
-        Initialize a MutableTorchTensorRTModule. This module can be manipulated just as a normal PyTorch module
-        and all TRT compilation and refit happens underthe hood as the user is using it. Modifying its attribute or
-        loading a different state_dict can trigger refit/recompilation that will be handled in the next forward run.
-
-        MutableTorchTensorRTModule takes a PyTorch module and a set of settings to configure the compiler.
-        After compilation is finished, MutableTorchTensorRTModule maintains the connection between the TRT graph module
-        and the original PyTorch module. And modification to MutableTorchTensorRTModule will reflect in both TRT graph module
-        and original PyTorch module.
-
 
         Arguments:
             pytorch_model (torch.nn.module): Source module that needs to be accelerated
@@ -148,7 +150,6 @@ def __init__(
         assert (
             make_refitable
         ), "'make_refitable' has to be True for a MutableTorchTensorRTModule."
-        make_refitable = True
         compilation_options = {
             "enabled_precisions": (
                 enabled_precisions
@@ -309,6 +310,7 @@ def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
 
     @staticmethod
     def process_kwarg_inputs(inputs: Any) -> Any:
+        # Process kwarg inputs into a form acceptable to Torch-TensorRT
         if isinstance(inputs, dict):
             # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
             return {
@@ -537,7 +539,7 @@ def load(path: str) -> Any:
 
 
 def recursively_remove_trigger(obj: Any) -> Any:
-    # Not save: If the object has a loop (such as a doubly linkded list), this will cause infinite recursion
+    # Not safe: If the object has a circular reference (such as a doubly linked list), this will cause infinite recursion
     if obj.__class__.__name__ == "ChangeTriggerWrapper":
         obj = obj.instance
 
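The "Not safe" comment above flags a real hazard: recursing over an object graph that contains a circular reference never terminates. A common remedy is to memoize visited object ids. The function below is a generic, hypothetical illustration of that pattern, not the module's actual implementation; unwrap_recursive and its dict/list-only handling are assumptions made for brevity:

    from typing import Any, Optional

    def unwrap_recursive(obj: Any, _seen: Optional[set] = None) -> Any:
        # Remember visited container ids so circular references (e.g. a
        # doubly linked list) terminate instead of recursing forever.
        if _seen is None:
            _seen = set()
        if id(obj) in _seen:
            return obj
        _seen.add(id(obj))
        if isinstance(obj, dict):
            return {k: unwrap_recursive(v, _seen) for k, v in obj.items()}
        if isinstance(obj, list):
            return [unwrap_recursive(v, _seen) for v in obj]
        return obj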