
🐛 [Bug] aten.index converter doesn't work when index is a TRTTensor #2480

Closed
@peri044


Bug Description

The aten.index converter fails for operators with a data-dependent output shape (here, boolean-mask indexing). Querying the output shape of the layer created by the aten.index converter produces the error below and invalidates the INetwork.

To reproduce, use the https://github.com/pytorch/TensorRT/tree/dyn_2.2 branch, which contains the sym_int converter implementation required by the reproducer.

Error:

gather_layer.get_output(0).shape # Print the output shape of the gather layer created by the index converter
[11/21/2023-00:33:07] [TRT] [E] 3: [GATHER]-[aten_ops.index.Tensor]-[__/index_index_gather]: only kINT32 allowed for input 1 to this layer.
(1, 3, 4, 4, 3, 4, 4)
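The printed shape `(1, 3, 4, 4, 3, 4, 4)` matches what a default-mode TensorRT gather on axis 0 produces when the `(1, 3, 4, 4)` mask itself is used as the index tensor: the output shape is the data shape with the gathered axis replaced by the full index shape. The sketch below, built around a hypothetical `gather_output_shape` helper, reproduces that arithmetic. Reading the log this way (the boolean mask passed straight through as gather indices, which would also explain the "only kINT32 allowed for input 1" error) is an inference from the output, not something the issue confirms.

```python
# Hypothetical helper illustrating the default-mode gather shape rule:
#   out.shape = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
def gather_output_shape(data_shape, indices_shape, axis=0):
    return tuple(data_shape[:axis]) + tuple(indices_shape) + tuple(data_shape[axis + 1:])


# Shapes from the reproducer: x is (1, 3, 4, 4) and the boolean mask is also (1, 3, 4, 4).
x_shape = (1, 3, 4, 4)
mask_shape = (1, 3, 4, 4)

# Treating the mask as gather indices on axis 0 yields the unexpected rank-7 shape.
print(gather_output_shape(x_shape, mask_shape, axis=0))  # (1, 3, 4, 4, 3, 4, 4)
```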

To Reproduce

import torch
import torch_tensorrt

class DDS(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x, mask):
        out = x[mask]
        out = self.relu(out)
        return out

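model = DDS().eval().cuda()
x = torch.randn(1, 3, 4, 4).cuda()
# Boolean mask with the same shape as x, so x[mask] has a data-dependent output shape
y = torch.rand((1, 3, 4, 4), device="cuda") < 0.9

# Compiling with the dynamo IR sends x[mask] (aten.index.Tensor) to the aten.index converter
trt_model = torch_tensorrt.compile(model, inputs=[x, y],
                                    ir="dynamo",
                                    min_block_size=1,
                                    debug=True)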
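For comparison, in eager PyTorch `x[mask]` with a mask of the same shape as `x` returns a 1-D tensor of the selected elements in row-major order, so its length equals the number of `True` entries and is only known at runtime. A minimal pure-Python sketch of those semantics follows, written as a nonzero-then-gather decomposition over flattened data. The helper names are hypothetical, and this decomposition is just one common way to express boolean indexing, not necessarily how the Torch-TensorRT converter lowers it; per the error message, any TensorRT gather in such a lowering would also need INT32 indices.

```python
def nonzero(flat_mask):
    """Row-major positions of True entries (analogous to a NonZero op)."""
    return [i for i, keep in enumerate(flat_mask) if keep]


def gather(flat_data, indices):
    """Pick the elements of flat_data at the given integer positions."""
    return [flat_data[i] for i in indices]


def boolean_mask_select(flat_data, flat_mask):
    """Hypothetical helper: x[mask] for a mask with the same shape as x.

    Both arguments are flattened in row-major order. The result is 1-D and
    its length equals the number of True entries, i.e. it is data-dependent.
    """
    if len(flat_data) != len(flat_mask):
        raise ValueError("mask must have the same number of elements as the data")
    return gather(flat_data, nonzero(flat_mask))


# Data of shape (2, 3) flattened row-major, with a mask of the same shape.
data = [10, 11, 12, 13, 14, 15]
mask = [True, False, True, False, False, True]
print(boolean_mask_select(data, mask))  # [10, 12, 15]
```

Changing only the mask values changes the output length (an all-False mask gives `[]`), which is what makes the output shape data-dependent rather than something that can be fixed at build time.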


