Skip to content

fix: Allow low rank inputs in Python Runtime #2282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 26 additions & 37 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
from torch.nn import Module
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
logger = logging.getLogger(__name__)


class PythonTorchTensorRTModule(Module): # type: ignore[misc]
Expand All @@ -22,14 +23,12 @@ def __init__(
engine: trt.ICudaEngine,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
cuda_graph_batch_size: int = -1,
):
super(PythonTorchTensorRTModule, self).__init__()
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
self.engine = engine
self.input_names = input_names if input_names is not None else []
self.output_names = output_names if output_names is not None else []
self.cuda_graph_batch_size = cuda_graph_batch_size
self.initialized = False
self._initialize()

Expand Down Expand Up @@ -107,7 +106,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
state_dict[prefix + "input_names"] = self.input_names
state_dict[prefix + "output_names"] = self.output_names
state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size

def _load_from_state_dict(
self,
Expand Down Expand Up @@ -156,8 +154,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

# This is only used when the trt engine is using implicit batch dim.
batch_size = inputs[0].shape[0]
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
bindings: List[Any] = [None] * (
len(self.input_names)
Expand All @@ -166,25 +162,29 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
)

for i, input_name in enumerate(self.input_names):
assert inputs[
i
].is_cuda, f"{i}th input({input_name}) is not on cuda device."
if not contiguous_inputs[i].is_cuda:
logger.warning(
f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
"This tensor is being moved by the runtime but for performance considerations, "
"ensure your inputs are all on GPU and open an issue here "
"(https://github.com/pytorch/TensorRT/issues) if this warning persists."
)
contiguous_inputs = (
contiguous_inputs[:i]
+ [contiguous_inputs[i].cuda()]
+ contiguous_inputs[i + 1 :]
)

assert (
inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

idx = self.input_binding_indices_in_order[i]
bindings[idx] = contiguous_inputs[i].data_ptr()

if not self.engine.has_implicit_batch_dimension:
self.context.set_binding_shape(
idx, tuple(contiguous_inputs[i].shape)
)
else:
assert inputs[i].size()[1:] == self.input_shapes[i], (
f"Shape mismatch for {i}th input({input_name}). "
f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
)
self.context.set_binding_shape(
idx, tuple(contiguous_inputs[i].shape)
)

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
Expand All @@ -193,10 +193,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
outputs: List[torch.Tensor] = []

for i, idx in enumerate(self.output_binding_indices_in_order):
if self.engine.has_implicit_batch_dimension:
shape = (batch_size,) + self.output_shapes[i]
else:
shape = tuple(self.context.get_binding_shape(idx))
shape = tuple(self.context.get_binding_shape(idx))

output = torch.empty(
size=shape,
Expand All @@ -207,10 +204,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
bindings[idx] = output.data_ptr()

for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
if self.engine.has_implicit_batch_dimension:
shape = (batch_size,) + self.hidden_output_shapes[i]
else:
shape = tuple(self.context.get_binding_shape(idx))
shape = tuple(self.context.get_binding_shape(idx))

output = torch.empty(
size=shape,
Expand All @@ -222,14 +216,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:TensorRTRuntime"
):
if self.engine.has_implicit_batch_dimension:
self.context.execute_async(
batch_size, bindings, torch.cuda.current_stream().cuda_stream
)
else:
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)

if len(outputs) == 1:
return outputs[0]
Expand Down
81 changes: 81 additions & 0 deletions tests/py/dynamo/runtime/test_python_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests


class TestLowRankInputs(TestCase):
def test_0D_input(self):
class Tensor0DInput(torch.nn.Module):
def forward(self, x):
return x * 7

inputs = [
torch.tensor(
3,
)
.cuda()
.int(),
]

fx_graph = torch.fx.symbolic_trace(Tensor0DInput())

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
use_python_runtime=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
msg=f"0D-Tensor TRT outputs don't match with the original model.",
)
torch._dynamo.reset()

def test_1D_input(self):
class Tensor1DInput(torch.nn.Module):
def forward(self, x, y):
return (x + 7.1) / (y * 2.1)

inputs = [torch.rand((3, 1)).cuda(), torch.rand((3, 1)).cuda()]

fx_graph = torch.fx.symbolic_trace(Tensor1DInput())

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
use_python_runtime=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
msg=f"1D-Tensor TRT outputs don't match with the original model.",
)

# Validate that the runtime moves cpu inputs to cuda
optimized_model(torch.rand((3, 1)), torch.rand((3, 1)))

torch._dynamo.reset()


if __name__ == "__main__":
run_tests()