
🐛 [Bug] Advanced Indexing/GatherND compilation causes error: isObject() INTERNAL ASSERT FAILED at "libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None #1274

@chaoz-dev


Bug Description

PyTorch supports GatherND-style gathers through advanced indexing. For example, for a 3-D tensor x:

x[[0,1], :, None, torch.tensor((0,1))]

is a valid indexing operation.
However, compiling such a model with Torch-TensorRT fails with the following error:

isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None
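As a small illustration of these GatherND semantics (the input shape and index lists are borrowed from the reproduction script below, purely for illustration), the two index lists are paired element-wise, so the result matches an explicit per-pair gather of x[r, :, c]:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4, 5)  # same shape as in the reproduction script
rows = [1, 0]  # indices into dim 0
cols = [1, 0]  # indices into dim 2

# Advanced indexing: the two index lists are broadcast together and
# consumed pairwise; because they are separated by a slice, the gathered
# dimension is placed first in the result.
out = x[rows, :, cols]

# Equivalent explicit gather: one slice x[r, :, c] per (r, c) pair.
ref = torch.stack([x[r, :, c] for r, c in zip(rows, cols)])

print(out.shape)  # torch.Size([2, 4])
print(torch.equal(out, ref))  # True
```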

To Reproduce

Run the following script (gathernd.py) in the nvcr.io/nvidia/pytorch:22.07-py3 container:

  import tensorrt
  import torch
  import torch_tensorrt

  class GatherNDListModel(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, x):
          return x[[1,0], :, [1,0]]


  class GatherNDTensorModel(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.index = torch.tensor((1,0))

      def forward(self, x):
          return x[self.index, :, self.index]

  if __name__ == '__main__':
      torch.manual_seed(0)
      t = torch.randn(3, 4, 5).cuda()

      try:
          m = GatherNDListModel()
          o = m(t)
          print(f"PyTorch: {o.shape}")

          mt = torch_tensorrt.compile(m, inputs=[t], truncate_long_and_double=True)
          o = mt(t)
          print(f"TRT: {o.shape}")
      except Exception as err:
          print(f"GatherND list indexing failed: {err}")

      try:
          m = GatherNDTensorModel()
          o = m(t)
          print(f"PyTorch: {o.shape}")

          mt = torch_tensorrt.compile(m, inputs=[t], truncate_long_and_double=True)
          o = mt(t)
          print(f"TRT: {o.shape}")
      except Exception as err:
          print(f"GatherND tensor indexing failed: {err}")

The script prints the following:

PyTorch: torch.Size([2, 4])
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
GatherND list indexing failed: isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None
PyTorch: torch.Size([2, 4])
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
GatherND tensor indexing failed: isObject() INTERNAL ASSERT FAILED at "bazel-out/k8-opt/bin/external/libtorch/_virtual_includes/ATen/ATen/core/ivalue_inl.h":123, please report a bug to PyTorch. Expected Object but got None

Note that the plain PyTorch model (i.e., without TRT) produces the expected shapes, but the TRT conversion fails.
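The error complains about an unexpected None. As a diagnostic sketch (not a confirmed root cause), tracing the model on CPU and printing its TorchScript graph shows where None values can appear: advanced indexing is expected to be recorded as aten::index, whose index list has type Tensor?[], with the ":" position filled by a None entry. Tracing is used here only for convenience; the graph Torch-TensorRT actually builds may differ.

```python
import torch


class GatherNDListModel(torch.nn.Module):
    # Same forward as GatherNDListModel in the reproduction script.
    def forward(self, x):
        return x[[1, 0], :, [1, 0]]


# Trace on CPU; TensorRT is not involved at all here.
traced = torch.jit.trace(GatherNDListModel(), torch.randn(3, 4, 5))
graph_text = str(traced.graph)
print(graph_text)

# The indexing should appear as an aten::index call whose index list
# contains a NoneType constant in the slot for ":".
print("aten::index" in graph_text)
```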

Expected behavior

The TRT compilation should succeed and produce the same output as the non-TRT version.
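A minimal sketch of checking this expectation, assuming a hypothetical helper check_same_output (not part of Torch-TensorRT). Since the compiled module cannot be built while this bug is open, a TorchScript trace of the same model stands in for it here; with a working compile, the compiled mt from the script (and a CUDA input) would be passed instead.

```python
import torch


class GatherNDListModel(torch.nn.Module):
    # Same forward as GatherNDListModel in the reproduction script.
    def forward(self, x):
        return x[[1, 0], :, [1, 0]]


def check_same_output(reference, candidate, x, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: raise AssertionError if the two callables disagree on x.

    torch.testing.assert_close also checks shape, dtype and device.
    """
    expected = reference(x)
    actual = candidate(x)
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
    return True


torch.manual_seed(0)
t = torch.randn(3, 4, 5)
m = GatherNDListModel()

# Stand-in for the Torch-TensorRT compiled module (assumption for this sketch).
stand_in = torch.jit.trace(m, t)
print(check_same_output(m, stand_in, t))  # True
```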

Environment

Build information about Torch-TensorRT can be found by turning on debug messages
nvcr.io/nvidia/pytorch:22.07-py3

Additional context

https://partners.nvidia.com/Bug/ViewBug/3735309

Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters)
