Bug Description
In PyTorch, GatherND-style operations are expressed through advanced (NumPy-style) indexing. For example, for a tensor `x` with at least three dimensions,
`x[[0, 1], :, None, torch.tensor((0, 1))]`
is a valid indexing operation.
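To make the semantics of the pattern used in the reproduction script below (`x[[1, 0], :, [1, 0]]`) concrete, here is a minimal pure-Python sketch on nested lists. It is only an illustration of the expected result, not how PyTorch or TensorRT implement it; the helper names `gather_rows_cols` and `shape` are hypothetical. The two index lists are paired element-wise, and because they are separated by a slice, the gathered dimension comes first in the result, giving shape `(2, 4)` for a `(3, 4, 5)` input.

```python
from typing import List, Sequence

Tensor3D = List[List[List[float]]]


def gather_rows_cols(x: Tensor3D, rows: Sequence[int], cols: Sequence[int]) -> List[List[float]]:
    """Hypothetical reference for x[rows, :, cols] on a 3-D nested list.

    Sketch assumption: rows and cols have equal length (no broadcasting).
    """
    if len(rows) != len(cols):
        raise ValueError("index lists must have equal length in this sketch")
    out = []
    for i, j in zip(rows, cols):
        # Each (i, j) pair selects one "row" and keeps the whole middle dimension.
        out.append([x[i][k][j] for k in range(len(x[i]))])
    return out


def shape(t: object) -> tuple:
    """Return the shape of a regular nested list by following first elements."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


if __name__ == "__main__":
    # Encode each element's coordinates (a, b, c) as 100*a + 10*b + c.
    x = [[[100 * a + 10 * b + c for c in range(5)] for b in range(4)] for a in range(3)]
    y = gather_rows_cols(x, [1, 0], [1, 0])
    print(shape(y))  # (2, 4), matching "PyTorch: torch.Size([2, 4])" below
```

The index pairs here are `(1, 1)` and `(0, 0)`, so the first output row is `x[1][:][1]` and the second is `x[0][:][0]`.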
However, compiling such a model with Torch-TensorRT fails with the following error:
```
isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None
```
To Reproduce
Run the following script (`gathernd.py`) in the `nvcr.io/nvidia/pytorch:22.07-py3` container:
```python
import tensorrt
import torch
import torch_tensorrt


class GatherNDListModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Advanced indexing with Python lists on dims 0 and 2
        return x[[1, 0], :, [1, 0]]


class GatherNDTensorModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.index = torch.tensor((1, 0))

    def forward(self, x):
        # Same indexing pattern, but with a tensor index
        return x[self.index, :, self.index]


if __name__ == '__main__':
    torch.manual_seed(0)
    t = torch.randn(3, 4, 5).cuda()

    try:
        m = GatherNDListModel()
        o = m(t)
        print(f"PyTorch: {o.shape}")
        mt = torch_tensorrt.compile(m, inputs=[t], truncate_long_and_double=True)
        o = mt(t)
        print(f"TRT: {o.shape}")
    except Exception as err:
        print(f"GatherND list indexing failed: {err}")

    try:
        m = GatherNDTensorModel()
        o = m(t)
        print(f"PyTorch: {o.shape}")
        mt = torch_tensorrt.compile(m, inputs=[t], truncate_long_and_double=True)
        o = mt(t)
        print(f"TRT: {o.shape}")
    except Exception as err:
        print(f"GatherND tensor indexing failed: {err}")
```
The script prints:
```
PyTorch: torch.Size([2, 4])
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
GatherND list indexing failed: isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None
PyTorch: torch.Size([2, 4])
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
GatherND tensor indexing failed: isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None
```
Note that the plain PyTorch model (i.e., without TensorRT) produces the expected shapes in both cases, but compilation with Torch-TensorRT fails for both the list-index and the tensor-index variants.
Expected behavior
Compilation with Torch-TensorRT should succeed, and the compiled module should produce the same output as the non-TRT version.
Environment
Container: `nvcr.io/nvidia/pytorch:22.07-py3`