Is your feature request related to a problem? Please describe.
In this graph:
```
INFO: [Torch-TensorRT - Debug Build] - Partitioned Graph: [Segment Block @0:
    Target: TensorRT

    Graph: graph(%index.1 : Tensor,
          %data.1 : Tensor):
      %2 : int = prim::Constant[value=4]() # test_int64.py:28:0
      %3 : bool = prim::Constant[value=0]() # test_int64.py:28:0
      %4 : NoneType = prim::Constant()
      %index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
      %data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
      return (%index, %data.3)

Segment Block @1:
    Target: Torch

    Graph: graph(%data.3 : Tensor,
          %index : Tensor):
      %2 : int = prim::Constant[value=1]() # test_int64.py:30:0
      %0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
      return (%0)
```
%index is converted to int32, but in Segment Block @1 the aten::scatter op requires the index tensor to be int64 and receives int32 instead. This happens because TensorRT doesn't support int64, so Torch-TensorRT casts every int64 value to int32 so that it can run in TensorRT. However, when partitioning is enabled, some ops that fall back to Torch still need int64 inputs to run.
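A minimal reproduction sketch along the lines of the partitioned graph above (shapes, values, and the compile options are guesses, not the original test_int64.py; exact option names may differ by version):

```python
import torch
import torch_tensorrt


class ScatterModel(torch.nn.Module):
    def forward(self, index, data):
        index = index.to(torch.int64)      # aten::to, downcast to int32 inside the TRT segment
        data = data * data                 # aten::mul, runs in the TensorRT segment
        return data.scatter(1, index, 1)   # aten::scatter falls back to Torch, expects an int64 index


model = torch.jit.script(ScatterModel()).eval().cuda()

index = torch.zeros(4, 4, dtype=torch.int64, device="cuda")
data = torch.randn(4, 4, device="cuda")

trt_model = torch_tensorrt.compile(
    model,
    inputs=[index, data],   # example inputs; the int64 index is what gets truncated to int32
    min_block_size=1,       # encourage partitioning so aten::scatter stays in a Torch block
)
out = trt_model(index, data)  # expected to fail with the dtype mismatch described above
```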
Describe the solution you'd like
This could be supported by recording every aten::to operation and then inserting the corresponding casts at the boundaries between the Torch and TensorRT segments.
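A rough sketch of that bookkeeping against the TorchScript Python IR bindings; collect_int64_casts is a hypothetical helper, not an existing Torch-TensorRT API, and the real fix would live in the C++ partitioner:

```python
import torch

INT64_DTYPE = 4  # ScalarType::Long, as encoded in prim::Constant[value=4] above


def collect_int64_casts(graph):
    """Record every value the user explicitly cast to int64 via aten::to."""
    int64_values = set()
    for node in graph.nodes():
        if node.kind() != "aten::to":
            continue
        dtype_node = node.inputsAt(1).node()  # second input of aten::to is the target dtype
        if (dtype_node.kind() == "prim::Constant"
                and dtype_node.hasAttribute("value")
                and dtype_node.kindOf("value") == "i"
                and dtype_node.i("value") == INT64_DTYPE):
            int64_values.add(node.output().debugName())
    return int64_values


# With this record, the partitioner could insert an aten::to(..., torch.int64) on each
# recorded value wherever a TensorRT segment feeds a Torch segment, so aten::scatter
# in Segment Block @1 receives the int64 index it expects.
@torch.jit.script
def scatter_fn(index: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
    index = index.to(torch.int64)
    data = data * data
    return data.scatter(1, index, 1)


print(collect_int64_casts(scatter_fn.graph))  # e.g. {'index'}
```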