Description
As we know, when a model contains operators that Torch-TensorRT does not support, the model is partitioned into TensorRT and Torch subgraphs. TensorRT does not support int64 values, so with truncate_long_and_double enabled, int64 is truncated to int32.
In some cases an operator in the Torch subgraph consumes an int64 value (e.g. aten::index), and that value is produced by a TensorRT subgraph, where it has already been truncated to int32. This causes an error. We need to track these data type conversions and automatically convert the value back to its original type at the boundary between the TensorRT and Torch subgraphs.
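For context, ops such as aten::scatter_ reject int32 index tensors outright, which is why the value truncated inside the TensorRT segment can no longer be consumed by the Torch segment. A minimal standalone illustration, independent of Torch-TensorRT:

import torch

data = torch.randn(5, 5)
index32 = torch.randint(0, 4, (2, 2), dtype=torch.int32)

# scatter_ expects an int64 index tensor; an int32 index raises a RuntimeError
try:
    data.scatter_(1, index32, 1)
except RuntimeError as err:
    print(err)  # e.g. "scatter(): Expected dtype int64 for index"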
Here is a typical case:
import torch
import torch.nn as nn
import torch_tensorrt

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, data, index):
        src = 1
        # This cast ends up in the TensorRT segment, where int64 is truncated to int32
        index = index.to(torch.int64)
        data = data * data
        # scatter_ is forced to run in Torch and expects an int64 index
        data = data.scatter_(1, index, src)
        data = data + 1
        return data

data = torch.randn([5, 5])
index = torch.randint(0, 4, [2, 2], dtype=torch.int32)

compile_spec = {
    "inputs": None,
    "device": {
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    },
    "truncate_long_and_double": True,
    "require_full_compilation": False,
    # keep scatter in Torch so the graph gets partitioned
    "torch_executed_ops": ["aten::scatter_", "aten::scatter"],
    "min_block_size": 1
}

net = Net()
model = torch.jit.trace(net, (data, index))

torch_type = torch.float32
min_shape = [5, 5]
data2 = torch_tensorrt.Input(shape=min_shape, dtype=torch_type)
torch_type = torch.int32
index2 = torch_tensorrt.Input(shape=min_shape, dtype=torch_type)
inputs = [data2, index2]
compile_spec["inputs"] = inputs

with torch_tensorrt.logging.debug():
    trt_mod = torch_tensorrt.ts.compile(model, **compile_spec)
    inputs = [data.cuda(), index.cuda()]
    output = trt_mod(*inputs)
    print(output)
Subgraph log:
INFO: [Torch-TensorRT - Debug Build] - Partitioned Graph: [Segment Block @0:
Target: TensorRT
Graph: graph(%index.1 : Tensor,
%data.1 : Tensor):
%2 : int = prim::Constant[value=4]() # test_int64.py:28:0
%3 : bool = prim::Constant[value=0]() # test_int64.py:28:0
%4 : NoneType = prim::Constant()
%index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
%data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
return (%index, %data.3)
Segment Block @1:
Target: Torch
Graph: graph(%data.3 : Tensor,
%index : Tensor):
%2 : int = prim::Constant[value=1]() # test_int64.py:30:0
%0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
return (%0)
Segment Block @2:
Target: TensorRT
Graph: graph(%1 : Tensor):
%2 : Tensor = prim::Constant[value={1}]() # test_int64.py:31:0
%3 : int = prim::Constant[value=1]() # test_int64.py:30:0
%0 : Tensor = aten::add(%1, %2, %3) # test_int64.py:31:0
return (%0)
]
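In this partition, Segment Block @0 (TensorRT) returns %index as int32, while Segment Block @1 (Torch) feeds it straight into aten::scatter, which requires int64. A minimal Python-level sketch of the cast-back this issue asks for, assuming the partitioner kept a record of the original dtypes; the original_dtypes map and restore_dtype helper below are hypothetical and not part of torch-tensorrt:

import torch

# Hypothetical record the partitioner would maintain: the original dtype of each
# boundary value that was truncated inside a TensorRT segment.
original_dtypes = {"index": torch.int64}

def restore_dtype(name, value):
    # Cast a TensorRT-produced tensor back to the dtype the Torch subgraph expects.
    wanted = original_dtypes.get(name)
    if wanted is not None and value.dtype != wanted:
        return value.to(wanted)
    return value

# e.g. the int32 %index coming out of Segment Block @0 would become int64 again
# before Segment Block @1 runs aten::scatter on it.
index32 = torch.randint(0, 4, (2, 2), dtype=torch.int32)
index64 = restore_dtype("index", index32)
print(index64.dtype)  # torch.int64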