
✨[Feature] Automatic conversion for int32<->int64 in fallback #1382

Closed
@inocsin

Description

As we know, if a model contains operators that Torch-TensorRT does not support, the model is partitioned into TensorRT and Torch subgraphs. TensorRT does not support int64 values and will truncate int64 to int32.

In some cases an operator in the Torch subgraph consumes an int64 value (for example aten::index or aten::scatter_, which require an int64 index tensor), and that value is produced by a TensorRT subgraph where it has already been truncated to int32; this causes a runtime error. We need to track these data type conversions and automatically convert values back to their original types at the boundary between the Torch and TensorRT segments.
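
Conceptually, the fix is to record the original dtype of every value that crosses a TensorRT -> Torch boundary and cast it back before the Torch segment consumes it. Below is a minimal Python-level sketch of that idea; the helper restore_boundary_dtypes is hypothetical (the real change would live in the partitioning code) and only illustrates the cast-back step.

import torch

# Hypothetical helper: cast boundary tensors back to the dtype they had before
# TensorRT truncated them.
def restore_boundary_dtypes(outputs, original_dtypes):
    restored = []
    for value, orig_dtype in zip(outputs, original_dtypes):
        if isinstance(value, torch.Tensor) and value.dtype != orig_dtype:
            # e.g. int32 coming out of the TensorRT block -> int64 expected by aten::scatter_
            value = value.to(orig_dtype)
        restored.append(value)
    return restored

# Example: the TensorRT block returns an int32 index, but the Torch block needs int64.
index_int32 = torch.randint(0, 4, (2, 2), dtype=torch.int32)
data = torch.randn(5, 5)
index, data = restore_boundary_dtypes([index_int32, data], [torch.int64, torch.float32])
assert index.dtype == torch.int64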

Here is a typical case that reproduces the problem:

import torch
import torch.nn as nn
import torch_tensorrt


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, data, index):
        src = 1
        # Cast the index to int64, as scatter_ requires. After partitioning,
        # this cast lands in the TensorRT segment and gets truncated to int32.
        index = index.to(torch.int64)
        data = data * data
        data = data.scatter_(1, index, src)
        data = data + 1
        return data

data = torch.randn([5, 5])
index = torch.randint(0, 4, [2, 2], dtype=torch.int32)

compile_spec = {
    "inputs": None,
    "device": {
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    },
    # allow int64/float64 values to be truncated to int32/float32 for TensorRT
    "truncate_long_and_double": True,
    "require_full_compilation": False,
    # force the scatter ops to run in a Torch (fallback) segment
    "torch_executed_ops": ["aten::scatter_", "aten::scatter"],
    "min_block_size": 1
}

net = Net()
model = torch.jit.trace(net, (data, index))

# Input specs: data is a [5, 5] float32 tensor, index is a [2, 2] int32 tensor
data2 = torch_tensorrt.Input(shape=[5, 5], dtype=torch.float32)
index2 = torch_tensorrt.Input(shape=[2, 2], dtype=torch.int32)

inputs = [data2, index2]

compile_spec["inputs"] = inputs

with torch_tensorrt.logging.debug():
    trt_mod = torch_tensorrt.ts.compile(model, **compile_spec)

# Running the compiled module fails: the index reaching aten::scatter_ in the
# Torch segment is int32 (truncated by TensorRT) instead of the expected int64.
inputs = [data.cuda(), index.cuda()]
output = trt_mod(*inputs)
print(output)

Subgraph log:

INFO: [Torch-TensorRT - Debug Build] - Partitioned Graph: [Segment Block @0:
    Target: TensorRT

    Graph: graph(%index.1 : Tensor,
      %data.1 : Tensor):
  %2 : int = prim::Constant[value=4]() # test_int64.py:28:0
  %3 : bool = prim::Constant[value=0]() # test_int64.py:28:0
  %4 : NoneType = prim::Constant()
  %index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
  %data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
  return (%index, %data.3)

Segment Block @1:
    Target: Torch

    Graph: graph(%data.3 : Tensor,
      %index : Tensor):
  %2 : int = prim::Constant[value=1]() # test_int64.py:30:0
  %0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
  return (%0)

Segment Block @2:
    Target: TensorRT

    Graph: graph(%1 : Tensor):
  %2 : Tensor = prim::Constant[value={1}]() # test_int64.py:31:0
  %3 : int = prim::Constant[value=1]() # test_int64.py:30:0
  %0 : Tensor = aten::add(%1, %2, %3) # test_int64.py:31:0
  return (%0)

]
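
In the partition above, Segment Block @0 (TensorRT) produces %index, which has been truncated to int32, while Segment Block @1 (Torch) passes it to aten::scatter, which requires an int64 index. A small plain-PyTorch demonstration of why that fails, independent of Torch-TensorRT:

import torch

data = torch.randn(5, 5)
index_int32 = torch.randint(0, 4, (2, 2), dtype=torch.int32)

try:
    # scatter_ rejects a non-int64 index, which is what the Torch segment receives
    data.clone().scatter_(1, index_int32, 1)
except RuntimeError as e:
    print("int32 index fails:", e)

# Casting the index back to int64 (what the automatic conversion should insert) works
data.clone().scatter_(1, index_int32.to(torch.int64), 1)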
