diff --git a/docsrc/py_api/torch_tensorrt.rst b/docsrc/py_api/torch_tensorrt.rst
index c2ddc9c701..ab108912db 100644
--- a/docsrc/py_api/torch_tensorrt.rst
+++ b/docsrc/py_api/torch_tensorrt.rst
@@ -32,6 +32,9 @@ Functions
 Classes
 ---------
+.. autoclass:: MutableTorchTensorRTModule
+   :members:
+   :special-members: __init__
 .. autoclass:: Input
    :members:
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index dac0cf1161..22ae16ec0c 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -13,4 +13,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
+* :ref:`mutable_torchtrt_module_example`: Compiling, using, and modifying a TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
new file mode 100644
index 0000000000..84122e074b
--- /dev/null
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -0,0 +1,115 @@
+"""
+.. _mutable_torchtrt_module_example:
+
+Mutable Torch TensorRT Module
+===================================================================
+
+In this tutorial, we demonstrate how to use the Mutable Torch TensorRT Module to compile, interact with, and modify a TensorRT Graph Module.
+
+Compiling a Torch-TensorRT module is straightforward, but modifying the compiled module can be challenging, especially when it comes to maintaining the state and connection between the PyTorch module and the corresponding Torch-TensorRT module.
+In Ahead-of-Time (AoT) scenarios, integrating Torch-TensorRT with complex pipelines, such as the Hugging Face Stable Diffusion pipeline, becomes even more difficult.
+The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
+
+In this tutorial, we walk through:
+
+1. A sample workflow of the Mutable Torch TensorRT Module with ResNet 18
+2. Saving a Mutable Torch TensorRT Module
+3. Integration with a Hugging Face pipeline in a LoRA use case
+"""
+
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+
+np.random.seed(5)
+torch.manual_seed(5)
+inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+
+# %%
+# Initialize the Mutable Torch TensorRT Module with settings.
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+settings = {
+    "use_python_runtime": False,
+    "enabled_precisions": {torch.float32},
+    "make_refitable": True,
+}
+
+model = models.resnet18(pretrained=False).eval().to("cuda")
+mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
+# You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module.
+mutable_module(*inputs)
+
+# %%
+# Make modifications to the mutable module.
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# %%
+# Making changes to the mutable module can trigger a refit or a re-compilation. For example, loading a different state_dict or setting new weight values will trigger a refit, while adding a module to the model will trigger a re-compilation.
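+
+# As a small illustrative sketch (the same pattern used in the accompanying unit tests),
+# directly assigning a new parameter through the mutable module is also detected as a
+# weight change and is handled on the next forward pass. The random weights assigned here
+# are immediately replaced again by the full state_dict loaded below.
+mutable_module.conv1.weight = torch.nn.Parameter(
+    torch.rand_like(mutable_module.conv1.weight)
+)
+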
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
+mutable_module.load_state_dict(model2.state_dict())
+
+
+# Check the output
+# The refit happens when you call the mutable module again.
+expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
+for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+    assert torch.allclose(
+        expected_output, refitted_output, 1e-2, 1e-2
+    ), "Refit result is not correct. Refit failed."
+
+print("Refit succeeded!")
+
+# %%
+# Saving Mutable Torch TensorRT Module
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Currently, saving is only supported by the C++ runtime, not the Python runtime.
+torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
+reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
+
+# %%
+# Stable Diffusion with Hugging Face
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
+
+from diffusers import DiffusionPipeline
+
+with torch.no_grad():
+    settings = {
+        "use_python_runtime": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+        "make_refitable": True,
+    }
+
+    model_id = "runwayml/stable-diffusion-v1-5"
+    device = "cuda:0"
+
+    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
+    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
+
+    pipe = DiffusionPipeline.from_pretrained(
+        model_id, revision="fp16", torch_dtype=torch.float16
+    )
+    pipe.to(device)
+
+    # The only extra line you need
+    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
+
+    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    image.save("./without_LoRA_mutable.jpg")
+
+    # Standard Hugging Face LoRA loading procedure
+    pipe.load_lora_weights(
+        "stablediffusionapi/load_lora_embeddings",
+        weight_name="moxin.safetensors",
+        adapter_name="lora1",
+    )
+    pipe.set_adapters(["lora1"], adapter_weights=[1])
+    pipe.fuse_lora()
+    pipe.unload_lora_weights()
+
+    # Refit is triggered on the next call
+    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+    image.save("./with_LoRA_mutable.jpg")
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index 442afb0bc0..118b45d7b6 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -134,3 +134,6 @@ def _register_with_torch() -> None:
 from torch_tensorrt import dynamo  # noqa: F401
 from torch_tensorrt._compile import *  # noqa: F403
+from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
+    MutableTorchTensorRTModule,
+)
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 910e98fa41..0505d90f89 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -534,4 +534,4 @@ def save(
         exp_program = torch.export.export(
             module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
         )
-    torch.export.save(exp_program, file_path)
+    torch.export.save(exp_program, file_path)
\ No newline at end of file
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
new file mode 100644
index 0000000000..88b8c13d10
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -0,0 +1,651 @@
+import logging
+from copy import deepcopy
+from enum import Enum, auto
+from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple, 
Union + +import numpy as np +import torch +from torch.fx.node import Target +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype +from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._compiler import compile as dynamo_compile +from torch_tensorrt.dynamo._refit import refit_module_weights +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import to_torch_device, to_torch_tensorrt_device + +logger = logging.getLogger(__name__) + + +class RefitFlag(Enum): + UNKNOWN = auto() + NEEDS_REFIT = auto() + NEEDS_RECOMPILE = auto() + LIVE = auto() + + +class RefitState: + _state: RefitFlag = RefitFlag.NEEDS_RECOMPILE + + def set_state(self, state: RefitFlag) -> None: + if isinstance(state, RefitFlag): + self._state = state + else: + raise ValueError(f"Invalid state: {state}") + + def get_state(self) -> RefitFlag: + return self._state + + +class MutableTorchTensorRTModule(object): + """ + Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module. + All TensorRT compilation and refitting processes are handled automatically as you work with the module. + Any changes to its attributes or loading a different state_dict will trigger refitting or recompilation, + which will be managed during the next forward pass. + + The MutableTorchTensorRTModule takes a PyTorch module and a set of configuration settings for the compiler. + Once compilation is complete, the module maintains the connection between the TensorRT graph module and the original PyTorch module. + Any modifications made to the MutableTorchTensorRTModule will be reflected in both the TensorRT graph module and the original PyTorch module. + """ + + def __init__( + self, + pytorch_model: torch.nn.Module, + *, + device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, + disable_tf32: bool = _defaults.DISABLE_TF32, + assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + make_refitable: bool = _defaults.MAKE_REFITABLE, + debug: bool = _defaults.DEBUG, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + workspace_size: int = _defaults.WORKSPACE_SIZE, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Collection[Target]] = None, + torch_executed_modules: Optional[List[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = _defaults.DRYRUN, + hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, + timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + 
**kwargs: Any,
+    ) -> None:
+        """
+
+        Arguments:
+            pytorch_model (torch.nn.Module): Source module that needs to be accelerated
+
+        Keyword Arguments:
+            device (Union(torch_tensorrt.Device, torch.device, str)): Target device for TensorRT engines to run on ::
+
+                device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+            disable_tf32 (bool): Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
+            assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
+            sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
+            enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+            make_refitable (bool): Enable refitting so that the engine weights can be updated after compilation. Must be True for a MutableTorchTensorRTModule.
+            debug (bool): Enable debuggable engine
+            engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+            num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+            workspace_size (int): Maximum size of workspace given to TensorRT
+            dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
+            dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+            dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+            truncate_double (bool): Truncate weights provided in double (float64) to float32
+            require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
+            min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT
+            torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True
+            torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+            pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
+            max_aux_streams (Optional[int]): Maximum streams in the engine
+            version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (restrict to lean runtime operators to provide version forward compatibility for the engines)
+            optimization_level (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
+            use_python_runtime (bool): Return a graph using a pure Python runtime, reduces options for serialization
+            use_fast_partitioner (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
+            enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
+            dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
+            hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
+            timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+            **kwargs (Any): Additional keyword arguments (currently unused)
+        Returns:
+            MutableTorchTensorRTModule
+        """
+        # The order to initialize this module is:
+        # 1. Set init_finished to False
+        # 2. Initialize all attributes
+        # 3. Add the module base class
+        # 4. Set init_finished to True
+        # After initialization, no new attribute should be added to the module __dict__.
+        # Otherwise, it will cause undefined behavior.
+
+        object.__setattr__(self, "init_finished", False)
+        self.refit_state = RefitState()
+        self.pytorch_model = _make_refit_change_trigger(pytorch_model, self.refit_state)
+        self.original_model = pytorch_model
+        # Process settings
+        self.gm: Any = None
+        self.exp_program: Any = None
+        self.arg_inputs: tuple[Any, ...] = tuple()
+        self.kwarg_inputs: dict[str, Any] = {}
+        device = to_torch_tensorrt_device(device)
+        enabled_precisions = {dtype._from(p) for p in enabled_precisions}
+        assert (
+            make_refitable
+        ), "'make_refitable' has to be True for a MutableTorchTensorRTModule."
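+        # All of the options below are packed into CompilationSettings and forwarded
+        # to dynamo_compile whenever the module is (re)compiled.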
+ compilation_options = { + "enabled_precisions": ( + enabled_precisions + if enabled_precisions + else _defaults.ENABLED_PRECISIONS + ), + "debug": debug, + "device": device, + "assume_dynamic_shape_support": assume_dynamic_shape_support, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + "torch_executed_ops": ( + torch_executed_ops if torch_executed_ops is not None else set() + ), + "pass_through_build_failures": pass_through_build_failures, + "max_aux_streams": max_aux_streams, + "version_compatible": version_compatible, + "optimization_level": optimization_level, + "use_python_runtime": use_python_runtime, + "truncate_double": truncate_double, + "use_fast_partitioner": use_fast_partitioner, + "num_avg_timing_iters": num_avg_timing_iters, + "enable_experimental_decompositions": enable_experimental_decompositions, + "require_full_compilation": require_full_compilation, + "disable_tf32": disable_tf32, + "sparse_weights": sparse_weights, + "make_refitable": make_refitable, + "engine_capability": engine_capability, + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, + "dryrun": dryrun, + "hardware_compatible": hardware_compatible, + "timing_cache_path": timing_cache_path, + } + + self.settings = CompilationSettings(**compilation_options) + self.run_info: Optional[tuple[Any, ...]] = None + self.state_dict_metadata: dict[str, torch.Size] = {} + self.store_state_dict_metadata() + + cls = self.__class__ + self.__class__ = type( + self.original_model.__class__.__name__, + (cls, pytorch_model.__class__), + {}, + ) + self.init_finished = True + + def store_state_dict_metadata(self) -> None: + + for k, v in self.original_model.state_dict().items(): + self.state_dict_metadata[k] = v.shape + + def load_state_dict( + self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False + ) -> None: + self.refit_state.set_state(RefitFlag.NEEDS_REFIT) + self.original_model.load_state_dict(state_dict, strict=strict, assign=assign) + + @staticmethod + def _transform_state_dict(sd: Dict[str, Any]) -> Dict[str, torch.nn.Parameter]: + return {k: torch.nn.Parameter(v, requires_grad=False) for k, v in sd.items()} + + def update_refit_condition(self) -> None: + # 2-stage check to determine whether the module should be intact, refitted, or recompiled. + + # Default refit + self.refit_state.set_state(RefitFlag.NEEDS_REFIT) + + # Run the same inputs through pytorch model and compare the result to previous run of graph module + # to determine whether refit/recompilation is needed. If the output is the same, no further process needed. + if self.run_info: + args, kwargs, result = self.run_info + self.original_model.to(to_torch_device(self.settings.device)) + new_result = self.original_model(*args, **kwargs) + self.original_model.cpu() + torch.cuda.empty_cache() + if MutableTorchTensorRTModule.check_output_equal(result, new_result): + self.refit_state.set_state(RefitFlag.LIVE) + return + + # Since we do not have access to the previous state_dict, we can only use state_dict_metadata + # to determine whether the keys or weight shape is changed. + sd, sd_meta = self.original_model.state_dict(), self.state_dict_metadata + if sd.keys() != sd_meta.keys(): + # If keys are not identical, recompile. + self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) + return + + for k in sd.keys(): + if sd[k].shape != sd_meta[k]: + # If weight shapes are not identical, recompile. 
+ self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) + return + + return + + def refit_gm(self) -> None: + """ + Refit the TRT graph module with any updates. + This function should be called whenever the weight values get changed but the weight structure remains + the same. + MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module. + If it fails to catch the changes, please call this function manually to update the TRT graph module. + """ + self.original_model.to(to_torch_device(self.settings.device)) + if self.exp_program is None: + self.exp_program = torch.export.export( + self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs + ) + else: + self.exp_program._state_dict = ( + MutableTorchTensorRTModule._transform_state_dict( + self.original_model.state_dict() + ) + ) + self.gm = refit_module_weights( + self.gm, self.exp_program, use_weight_map_cache=True, in_place=True + ) + + self.original_model.cpu() + torch.cuda.empty_cache() + + def compile(self) -> None: + """ + (Re)compile the TRT graph module using the PyTorch module. + This function should be called whenever the weight structure get changed (shape, more layers...) + MutableTorchTensorRTModule automatically catches weight value updates and call this function to recompile. + If it fails to catch the changes, please call this function manually to recompile the TRT graph module. + """ + # Export the module + self.original_model.to(to_torch_device(self.settings.device)) + self.exp_program = torch.export.export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + ) + self.gm = dynamo_compile( + self.exp_program, + arg_inputs=self.arg_inputs, + kwarg_inputs=self.kwarg_inputs, + **self.settings.__dict__, + ) + self.original_model.cpu() + torch.cuda.empty_cache() + + def _validate_inputs(self, *args: Any, **kwargs: Any) -> None: + if ( + not self.arg_inputs + or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args) + or not MutableTorchTensorRTModule.check_inputs_equal( + self.kwarg_inputs, kwargs + ) + ): + logger.info("Input change detected.") + self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) + self.store_inputs(args, kwargs) + + def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None: + self.arg_inputs = arg_inputs + self.kwarg_inputs = kwarg_inputs + + @staticmethod + def process_kwarg_inputs(inputs: Any) -> Any: + # Process kwarg inputs to be acceptable for Torch-TensorRT + if isinstance(inputs, dict): + # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded. + return { + k: MutableTorchTensorRTModule.process_kwarg_inputs(v) + for k, v in inputs.items() + if (v is not None and not isinstance(v, bool)) + } + elif isinstance(inputs, torch.Tensor): + return inputs + elif isinstance(inputs, (int, float, np.ndarray)): + return torch.tensor(inputs) + elif isinstance(inputs, (list, tuple)): + if None not in inputs: + return type(inputs)( + [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs] + ) + + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the input. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + # Step 1: Check whether the input shape has changed + kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs) + self._validate_inputs(*args, **kwargs) + + # Step 2: If the flag is unknown, it could be a recompile or refit. 
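+        # update_refit_condition() re-runs the PyTorch module on the cached inputs and
+        # compares outputs and state_dict metadata to decide between LIVE, NEEDS_REFIT,
+        # and NEEDS_RECOMPILE.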
+        if self.refit_state.get_state() == RefitFlag.UNKNOWN:
+            # Update the flag
+            self.update_refit_condition()
+
+        # Step 3: Refit/recompile accordingly
+        if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
+            logger.info("(Re)Compiling the engine...")
+            self.compile()
+            self.store_state_dict_metadata()
+            self.refit_state.set_state(RefitFlag.LIVE)
+
+        elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
+            logger.info("Model weight change detected. Refitting the module...")
+            try:
+                self.refit_gm()
+            except Exception as e:
+                logger.error(e)
+                logger.error("Model refit failed. Recompiling the graph module.")
+                self.compile()
+                self.store_state_dict_metadata()
+            self.refit_state.set_state(RefitFlag.LIVE)
+
+        result = self.gm(*args, **kwargs)
+        # Store inputs and outputs for verification when the state is unknown
+        self.run_info = (args, kwargs, result)
+        return result
+
+    def to(self, device: str) -> None:
+        logger.warning("Original PyTorch model is moved. CPU offload may fail.")
+        self.original_model.to(device)
+
+    def __deepcopy__(self, memo: Any) -> Any:
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if k != "pytorch_model":
+                setattr(result, k, deepcopy(v, memo))
+        result.pytorch_model = _make_refit_change_trigger(
+            result.original_model, result.refit_state
+        )
+        return result
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return self.forward(*args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+
+        if name in self.__dict__:
+            # this object has it
+            return getattr(self, name)
+
+        if "pytorch_model" not in self.__dict__:
+            raise AttributeError(
+                "Module is not properly initialized. PyTorch model was not found in the module."
+            )
+
+        return getattr(self.pytorch_model, name)
+
+    def __delattr__(self, name: str) -> Any:
+
+        if name in self.__dict__:
+            # this object has it
+            return super().__delattr__(name)
+
+        return self.pytorch_model.__delattr__(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        # Once the module has finished initialization, any modification to attributes that
+        # do not exist in __dict__ will be handled by the wrapped PyTorch module.
+        if self.init_finished:
+            if name in self.__dict__:
+                object.__setattr__(self, name, value)
+            else:
+                # Capture attribute change
+                self.refit_state.set_state(RefitFlag.UNKNOWN)
+                # We want to make sure the original PyTorch model does not have a trigger wrapper
+                value = recursively_remove_trigger(value)
+                setattr(self.original_model, name, value)
+        else:
+            object.__setattr__(self, name, value)
+
+    @staticmethod
+    def check_output_equal(
+        output1: Any,
+        output2: Any,
+    ) -> bool:
+        # TODO: Move this to utils when all PRs are merged. This can be used by other functions.
+        if type(output1) != type(output2):
+            logger.warning(
+                "This module does not support using output verification to skip refit. 
Refit will be performed \ + whenever the state is UNKNOWN" + ) + return False + + if isinstance(output1, torch.Tensor): + if output1.shape != output2.shape: + return False + return torch.allclose(output1, output2, 1e-2, 1e-2) # type: ignore + + elif isinstance(output1, (tuple, list)): + if len(output1) != len(output2): + return False + for a, b in zip(output1, output2): + if not MutableTorchTensorRTModule.check_output_equal(a, b): + return False + return True + + elif isinstance(output1, dict): + if output1.keys() != output2.keys(): + return False + for a, b in zip(output1.values(), output2.values()): + if not MutableTorchTensorRTModule.check_output_equal(a, b): + return False + return True + + logger.warning( + "This module does not support using output verification to skip refit. Refit will be performed \ + whenever the state is UNKNOWN" + ) + return False + + @staticmethod + def check_inputs_equal( + input1: Any, + input2: Any, + ) -> bool: + # TODO: Add support for dynamic shape + if isinstance(input1, (tuple, list)): + if len(input1) != len(input2): + return False + for a, b in zip(input1, input2): + if type(a) != type(b): + return False + if isinstance(a, torch.Tensor) and a.shape != b.shape: + return False + elif isinstance(a, bool) and a != b: + return False + + elif isinstance(input1, dict): + if input1.keys() != input2.keys(): + return False + for a, b in zip(input1.values(), input2.values()): + if type(a) != type(b): + return False + if isinstance(a, torch.Tensor) and a.shape != b.shape: + return False + elif isinstance(a, bool) and a != b: + return False + elif isinstance( + a, (list, tuple, dict) + ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b): + return False + return True + + @staticmethod + def save(module: Any, path: str) -> None: + # Cast the object back to MutableTorchTensorRTModule to save + assert ( + not module.settings.use_python_runtime + ), "Python runtime does not support serialization. Save failed." + module.init_finished = False + module.__class__ = MutableTorchTensorRTModule + exp_program = module.exp_program + module.pytorch_model = None + module.exp_program = None + torch.save(module, path) + # Restore deleted attributes + module.exp_program = exp_program + module.pytorch_model = _make_refit_change_trigger( + module.original_model, module.refit_state + ) + cls = module.__class__ + module.__class__ = type( + module.original_model.__class__.__name__, + (cls, module.original_model.__class__), + {}, + ) + + module.init_finished = True + + @staticmethod + def load(path: str) -> Any: + # When the model get saved, init_finished is set to False. 
+        # The class is restored to MutableTorchTensorRTModule, and some attributes are deleted.
+        module = torch.load(path)
+        module.pytorch_model = _make_refit_change_trigger(
+            module.original_model, module.refit_state
+        )
+        module.original_model.to(to_torch_device(module.settings.device))
+        module.exp_program = torch.export.export(
+            module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
+        )
+        module.original_model.to("cpu")
+        cls = module.__class__
+        module.__class__ = type(
+            module.original_model.__class__.__name__,
+            (cls, module.original_model.__class__),
+            {},
+        )
+        module.init_finished = True
+        return module
+
+
+def recursively_remove_trigger(obj: Any) -> Any:
+    # Not safe: if the object has a circular reference (such as a doubly linked list), this will cause infinite recursion
+    if obj.__class__.__name__ == "ChangeTriggerWrapper":
+        obj = obj.instance
+
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            obj[k] = recursively_remove_trigger(v)
+    elif isinstance(obj, list):
+        for i, v in enumerate(obj):
+            obj[i] = recursively_remove_trigger(v)
+    else:
+        if not hasattr(obj, "__dict__") or isinstance(obj, (type,)):
+            return obj
+        for k, v in obj.__dict__.items():
+            if k[:2] != "__" or k[-2:] != "__":
+                # We don't want to touch built-in attributes such as __dict__
+                setattr(obj, k, recursively_remove_trigger(v))
+
+    return obj
+
+
+def _make_refit_change_trigger(obj: object, refit_state: RefitState) -> Any:
+    subclass: type = obj.__class__
+
+    class ChangeTriggerWrapper(subclass):  # type: ignore
+        # We inherit from the wrapped object's class so that the ChangeTriggerWrapper shares all of
+        # the functions of an ordinary object. This way, attributes accessed inside a function go through
+        # the __getattr__ function of ChangeTriggerWrapper instead of the object itself, and are thus
+        # recursively wrapped by ChangeTriggerWrapper.
+
+        def __init__(self, obj: Any):
+            object.__setattr__(self, "instance", obj)
+
+        def __getattr__(
+            self, name: str
+        ) -> Any:  # Called when the attribute does not exist
+            obj = getattr(self.instance, name)
+            if isinstance(obj, torch.nn.Parameter):
+                # Whenever the user retrieves an attribute that could be related to weights, we set the state to UNKNOWN
+                self._on_change()
+            if (
+                hasattr(obj, "__dict__") or isinstance(obj, (torch.nn.ModuleList, list))
+            ) and not isinstance(
+                obj, ChangeTriggerWrapper
+            ):  # prevent nesting wrappers
+                return _make_refit_change_trigger(obj, refit_state)
+            return obj
+
+        def __setattr__(self, name: str, value: Any) -> None:
+            # If we need to set __dict__ or instance, we set it directly on the trigger wrapper.
+            # Setting __dict__ is allowed because the PyTorch proxy uses __new__ to initialize a shallow
+            # copy of a module and explicitly sets __dict__. If we did not set __dict__ here, it would
+            # cause infinite recursion.
+            if name in ["__dict__", "instance"]:
+                object.__setattr__(self, name, value)
+                return
+            self._on_change()
+            # We want to make sure the original PyTorch model does not have a trigger wrapper
+            value = recursively_remove_trigger(value)
+            setattr(self.instance, name, value)
+
+        def __delattr__(self, name: str) -> None:
+            self._on_change()
+            delattr(
+                self.instance,
+                name,
+            )
+
+        def _on_change(self) -> None:
+            refit_state.set_state(RefitFlag.UNKNOWN)
+            logger.info(
+                "Attribute modification detected. The module will be refitted or recompiled later."
+ ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.instance(*args, **kwargs) + + def _call_impl(self, *args: Any, **kwargs: Any) -> Any: + return self.instance._call_impl(*args, **kwargs) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.instance.forward(*args, **kwargs) + + def __setitem__(self, item: str, value: Any) -> None: + self._on_change() + # We want to make sure the original PyTorch model does not have a trigger wrapper + value = recursively_remove_trigger(value) + self.instance.__setitem__(item, value) + + def __getitem__(self, items: str) -> Any: + obj = self.instance.__getitem__(items) + if isinstance(obj, ChangeTriggerWrapper): + return obj + return _make_refit_change_trigger(obj, refit_state) + + def __len__(self) -> int: + return len(self.instance) + + def __iter__(self) -> Iterator[Any]: + return iter(self.instance) + + return ChangeTriggerWrapper(obj) diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py index 3ecad11a55..08b23d55e0 100644 --- a/tests/py/dynamo/models/test_export_kwargs_serde.py +++ b/tests/py/dynamo/models/test_export_kwargs_serde.py @@ -51,7 +51,6 @@ def forward(self, x, b=5, c=None, d=None): model = net().eval().to("cuda") args = [torch.rand((1, 3, 224, 224)).to("cuda")] kwargs = { - "b": torch.tensor(6).to("cuda"), "d": {"value": torch.tensor(8).to("cuda")}, } diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py new file mode 100644 index 0000000000..593a4322b7 --- /dev/null +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -0,0 +1,473 @@ +# type: ignore +import os +import tempfile +import unittest + +import pytest +import timm +import torch +import torch.nn.functional as F +import torch_tensorrt as torch_trt +import torchvision.models as models +from torch import nn +from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import RefitFlag + +assertions = unittest.TestCase() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_resnet18(): + + torch.manual_seed(0) + inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + compile_spec = { + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "make_refitable": True, + } + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*inputs) + + mutable_module.load_state_dict(model2.state_dict()) + assertions.assertTrue( + mutable_module.refit_state.get_state() == RefitFlag.NEEDS_REFIT, + msg=f"Changing the attribute did not trigger the flag. Test failed.", + ) + # Check the output + expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + expected_outputs, refitted_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_save(): + + torch.manual_seed(0) + inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + # %% + # Compile the module for the first time and save it. 
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + compile_spec = { + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "make_refitable": True, + } + + model = models.resnet18(pretrained=False).eval().to("cuda") + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*inputs) + + save_path = os.path.join(tempfile.gettempdir(), "mutable_module.pkl") + torch_trt.MutableTorchTensorRTModule.save(mutable_module, save_path) + reload = torch_trt.MutableTorchTensorRTModule.load(save_path) + + loaded_outputs, trt_gm_outputs = reload(*inputs), mutable_module(*inputs) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + loaded_outputs, trt_gm_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_resnet18_modify_attribute(): + + torch.manual_seed(0) + inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + compile_spec = { + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "make_refitable": True, + } + + model = models.resnet18(pretrained=False).eval().to("cuda") + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*inputs) + + mutable_module.conv1.weight = nn.Parameter( + torch.rand_like(mutable_module.conv1.weight) + ) + assertions.assertEqual( + mutable_module.refit_state.get_state(), + RefitFlag.UNKNOWN, + msg=f"Changing the attribute did not trigger the flag. Test failed.", + ) + # Check the output + model.to("cuda") + expected_outputs, refitted_outputs = model(*inputs), mutable_module(*inputs) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 5e-2, 5e-2), + msg=f"The output of refitted Mutable Module is not correct.", + ) + + # # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_resnet18_modify_attribute_no_refit(): + + torch.manual_seed(0) + inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + compile_spec = { + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "make_refitable": True, + } + + model = models.resnet18(pretrained=False).eval().to("cuda") + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*inputs) + + mutable_module.conv1.weight = nn.Parameter( + mutable_module.original_model.conv1.weight + ) + assertions.assertEqual( + mutable_module.refit_state.get_state(), + RefitFlag.UNKNOWN, + msg=f"Changing the attribute did not trigger the flag. 
Test failed.", + ) + # Check the output + mutable_module.update_refit_condition() + assertions.assertEqual( + mutable_module.refit_state.get_state(), + RefitFlag.LIVE, + msg=f"update_refit_condition() failed to set the flag to LIVE.", + ) + # Check the output + model.to("cuda") + expected_outputs, refitted_outputs = model(*inputs), mutable_module(*inputs) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + msg=f"The output of refitted Mutable Module is not correct.", + ) + + # # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_custom_model_with_kwarg(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x, b=5, c=None, d=None): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + x = x + b + if c is not None: + x = x * c + if d is not None: + x = x - d["value"] + return self.fc1(x) + + torch.manual_seed(0) + model = net().eval().to("cuda") + args = [torch.rand((1, 3, 224, 224)).to("cuda")] + kwargs = { + "d": {"value": torch.tensor(8).to("cuda")}, + "b": torch.tensor(6).to("cuda"), + } + + compile_spec = { + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "make_refitable": True, + } + + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*args, **kwargs) + + torch.manual_seed(2) + model2 = net().eval().to("cuda") + mutable_module.load_state_dict(model2.state_dict()) + + # Check the output + expected_outputs, refitted_outputs = model2(*args, **kwargs), mutable_module( + *args, **kwargs + ) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + expected_outputs, refitted_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_custom_model_with_inplace_init(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + def set_weights(self): + nn.init.normal_(self.conv1.weight) + + torch.manual_seed(0) + model = net().eval().to("cuda") + args = [torch.rand((1, 3, 224, 224)).to("cuda")] + + compile_spec = { + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "make_refitable": True, + } + + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*args) + + mutable_module.set_weights() 
+ assertions.assertEqual( + mutable_module.refit_state.get_state(), + RefitFlag.UNKNOWN, + msg=f"Changing the attribute did not trigger the flag. Test failed.", + ) + + # Check the output + model.cuda() + expected_outputs, refitted_outputs = model(*args), mutable_module(*args) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + expected_outputs, refitted_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_custom_model_with_init_recompile(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + def set_layer(self): + self.fc1 = nn.Linear(12 * 56 * 56, 3) + + torch.manual_seed(0) + model = net().eval().to("cuda") + args = [torch.rand((1, 3, 224, 224)).to("cuda")] + + compile_spec = { + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "make_refitable": True, + } + + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*args) + + mutable_module.set_layer() + assertions.assertEqual( + mutable_module.refit_state.get_state(), + RefitFlag.UNKNOWN, + msg=f"Changing the attribute did not trigger the flag. 
Test failed.", + ) + + # Check the output + model.cuda() # move offloaded model from cpu to cuda + expected_outputs, refitted_outputs = model(*args), mutable_module(*args) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + expected_outputs, refitted_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@pytest.mark.unit +def test_custom_model_with_kwarg_different_input(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x, b=5, c=None, d=None): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + x = x + b + if c is not None: + x = x * c + if d is not None: + x = x - d["value"] + return self.fc1(x) + + torch.manual_seed(0) + model = net().eval().to("cuda") + args = [torch.rand((1, 3, 224, 224)).to("cuda")] + kwargs = { + "d": {"value": torch.tensor(8).to("cuda")}, + "b": torch.tensor(6).to("cuda"), + } + + compile_spec = { + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "make_refitable": True, + } + + mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) + mutable_module(*args, **kwargs) + + kwargs = { + "d": {"value": torch.tensor(8).to("cuda")}, + "b": torch.tensor(6).to("cuda"), + "c": torch.tensor(0.5).to("cuda"), + } + + # Check the output + model.cuda() + expected_outputs, refitted_outputs = model(*args, **kwargs), mutable_module( + *args, **kwargs + ) + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal( + expected_outputs, refitted_outputs + ), + msg=f"The output of saved and reloaded Mutable Module is not correct.", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_check_output_equal(): + torch.manual_seed(0) + a = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + torch.manual_seed(0) + b = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + assertions.assertTrue( + torch_trt.MutableTorchTensorRTModule.check_output_equal(a, b), + msg=f"test_check_output_equal is not correct.", + )