diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index e00b66d8f3..41baecc7ab 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)
 
 
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +23,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
 
         self.initialized = False
         self._initialize()
@@ -107,7 +106,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size
 
     def _load_from_state_dict(
         self,
@@ -156,8 +154,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.input_names
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
 
-            # This is only used when the trt engine is using implicit batch dim.
-            batch_size = inputs[0].shape[0]
             contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
             bindings: List[Any] = [None] * (
                 len(self.input_names)
@@ -166,25 +162,29 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             )
 
             for i, input_name in enumerate(self.input_names):
-                assert inputs[
-                    i
-                ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                if not contiguous_inputs[i].is_cuda:
+                    logger.warning(
+                        f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                        "This tensor is being moved by the runtime but for performance considerations, "
+                        "ensure your inputs are all on GPU and open an issue here "
+                        "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                    )
+                    contiguous_inputs = (
+                        contiguous_inputs[:i]
+                        + [contiguous_inputs[i].cuda()]
+                        + contiguous_inputs[i + 1 :]
+                    )
+
                 assert (
-                    inputs[i].dtype == self.input_dtypes[i]
-                ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == self.input_dtypes[i]
+                ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                 idx = self.input_binding_indices_in_order[i]
                 bindings[idx] = contiguous_inputs[i].data_ptr()
 
-                if not self.engine.has_implicit_batch_dimension:
-                    self.context.set_binding_shape(
-                        idx, tuple(contiguous_inputs[i].shape)
-                    )
-                else:
-                    assert inputs[i].size()[1:] == self.input_shapes[i], (
-                        f"Shape mismatch for {i}th input({input_name}). "
" - f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}." - ) + self.context.set_binding_shape( + idx, tuple(contiguous_inputs[i].shape) + ) with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessOutputs" @@ -193,10 +193,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: outputs: List[torch.Tensor] = [] for i, idx in enumerate(self.output_binding_indices_in_order): - if self.engine.has_implicit_batch_dimension: - shape = (batch_size,) + self.output_shapes[i] - else: - shape = tuple(self.context.get_binding_shape(idx)) + shape = tuple(self.context.get_binding_shape(idx)) output = torch.empty( size=shape, @@ -207,10 +204,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: bindings[idx] = output.data_ptr() for i, idx in enumerate(self.hidden_output_binding_indices_in_order): - if self.engine.has_implicit_batch_dimension: - shape = (batch_size,) + self.hidden_output_shapes[i] - else: - shape = tuple(self.context.get_binding_shape(idx)) + shape = tuple(self.context.get_binding_shape(idx)) output = torch.empty( size=shape, @@ -222,14 +216,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:TensorRTRuntime" ): - if self.engine.has_implicit_batch_dimension: - self.context.execute_async( - batch_size, bindings, torch.cuda.current_stream().cuda_stream - ) - else: - self.context.execute_async_v2( - bindings, torch.cuda.current_stream().cuda_stream - ) + self.context.execute_async_v2( + bindings, torch.cuda.current_stream().cuda_stream + ) if len(outputs) == 1: return outputs[0] diff --git a/tests/py/dynamo/runtime/test_python_runtime.py b/tests/py/dynamo/runtime/test_python_runtime.py new file mode 100644 index 0000000000..01ad150479 --- /dev/null +++ b/tests/py/dynamo/runtime/test_python_runtime.py @@ -0,0 +1,81 @@ +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + + +class TestLowRankInputs(TestCase): + def test_0D_input(self): + class Tensor0DInput(torch.nn.Module): + def forward(self, x): + return x * 7 + + inputs = [ + torch.tensor( + 3, + ) + .cuda() + .int(), + ] + + fx_graph = torch.fx.symbolic_trace(Tensor0DInput()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + msg=f"0D-Tensor TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_1D_input(self): + class Tensor1DInput(torch.nn.Module): + def forward(self, x, y): + return (x + 7.1) / (y * 2.1) + + inputs = [torch.rand((3, 1)).cuda(), torch.rand((3, 1)).cuda()] + + fx_graph = torch.fx.symbolic_trace(Tensor1DInput()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + 
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"1D-Tensor TRT outputs don't match with the original model.",
+        )
+
+        # Validate that the runtime moves cpu inputs to cuda
+        optimized_model(torch.rand((3, 1)), torch.rand((3, 1)))
+
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
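
Reviewer note: a minimal usage sketch of the behavior this diff enables, not part of the patch. It mirrors the compile call used in the new tests (ir="torch_compile", use_python_runtime=True); the module name MulModule and the input shapes are illustrative assumptions. With this change, an input accidentally left on CPU is moved to the GPU by the runtime (with a warning) instead of failing the old per-input is_cuda assertion.

    # Illustrative sketch (hypothetical names, not part of this diff).
    # Assumes a CUDA-capable machine with torch and torch_tensorrt installed.
    import torch
    import torch_tensorrt

    class MulModule(torch.nn.Module):
        def forward(self, x):
            return x * 7

    inputs = [torch.rand((3, 4)).cuda()]
    optimized_model = torch_tensorrt.compile(
        MulModule(),
        "torch_compile",
        inputs,
        min_block_size=1,
        use_python_runtime=True,
    )

    # Previously a CPU input tripped the is_cuda assertion; now the runtime
    # logs a warning and calls .cuda() on the tensor before execution.
    result = optimized_model(torch.rand((3, 4)))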