
🐛 [Bug] aten.expand fails when rank disagrees with tensor shape #2183

Closed
@gs-olive

Description

Bug Description

When the number of requested sizes differs from the rank of the input tensor, conversion of the torch.ops.aten.expand operator fails because of this assertion in acc_ops_expand_tensor:

```python
@tensorrt_converter(acc_ops.expand)
def acc_ops_expand_tensor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    shape = list(kwargs["sizes"])
    input_val = get_trt_tensor(network, input_t, f"{name}_input")
    if network.has_implicit_batch_dimension:
        shape = shape[1:]
    ranks = len(input_val.shape)
    # TRT does not support different dimension size
    assert len(shape) == ranks
```

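This does not match PyTorch behavior: calling .expand on a Tensor does not require the expanded size to have the same rank as the original Tensor. The tensor can be expanded to a larger number of dimensions, with the new dimensions added at the front (see the torch.Tensor.expand documentation):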

```python
import torch

x = torch.ones(2, 2)
y = x.expand([5, 5, 5, 5, -1, -1])
print(y.shape)
# torch.Size([5, 5, 5, 5, 2, 2])
```
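To make the expected semantics concrete, here is a minimal pure-Python sketch of how the output shape of `Tensor.expand` is resolved under the documented rules: new dimensions are prepended, `-1` keeps an existing size and is not allowed for new dimensions, and only size-1 dimensions can be broadcast to a larger size. `expand_shape` is a hypothetical helper written for illustration; it does not call PyTorch or Torch-TensorRT.

```python
from typing import Sequence, Tuple


def expand_shape(input_shape: Sequence[int], sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: result shape of Tensor.expand(sizes), per PyTorch's documented rules."""
    rank = len(input_shape)
    if len(sizes) < rank:
        raise ValueError("expand: number of sizes must be >= the tensor's rank")
    num_new = len(sizes) - rank  # leading dimensions added in front of the input
    result = []
    for i, size in enumerate(sizes):
        if i < num_new:
            # New leading dimension: -1 (keep existing size) has no meaning here.
            if size < 0:
                raise ValueError(f"expand: -1 not allowed for new leading dimension {i}")
            result.append(size)
            continue
        existing = input_shape[i - num_new]
        if size == -1:
            result.append(existing)  # keep the existing size
        elif size < 0:
            raise ValueError(f"expand: invalid size {size}")
        elif existing == 1 or existing == size:
            result.append(size)  # broadcast a size-1 dim, or keep a matching one
        else:
            raise ValueError(f"expand: cannot expand dim of size {existing} to {size}")
    return tuple(result)


print(expand_shape((2, 2), [5, 5, 5, 5, -1, -1]))  # (5, 5, 5, 5, 2, 2)
```

The important case for this bug is `num_new > 0`: the rank of the result exceeds the rank of the input, which is exactly what `assert len(shape) == ranks` rejects.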

This is the error message in the converter:

```
File "~/TensorRT/py/torch_tensorrt/fx/converters/acc_ops_converters.py", line 2475, in acc_ops_expand_tensor
    assert len(shape) == ranks
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
AssertionError: While executing %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_3, [1, 512, 512]), kwargs =...
```

To Reproduce

See the code snippet above for the behavior the converter should match.

Expected behavior

The converter should succeed when the requested size has more dimensions than the input tensor, matching PyTorch's expand semantics.
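One way a converter could meet this expectation, mirroring PyTorch's rule that new dimensions are added at the front, is to first reshape the input to rank `len(sizes)` by prepending size-1 dimensions, so the existing `len(shape) == ranks` invariant holds before broadcasting. In TensorRT that reshape would presumably be done with a shuffle layer; the sketch below only computes the shapes involved. `prepend_ones_shape` is a hypothetical helper, not an existing Torch-TensorRT function.

```python
from typing import Sequence, Tuple


def prepend_ones_shape(input_shape: Sequence[int], target_rank: int) -> Tuple[int, ...]:
    """Hypothetical helper: left-pad input_shape with 1s so it has target_rank dims."""
    rank = len(input_shape)
    if target_rank < rank:
        # PyTorch also rejects fewer sizes than the tensor has dimensions.
        raise ValueError(f"cannot reduce rank {rank} to {target_rank} in expand")
    return (1,) * (target_rank - rank) + tuple(input_shape)


sizes = [5, 5, 5, 5, -1, -1]
padded = prepend_ones_shape((2, 2), len(sizes))
print(padded)  # (1, 1, 1, 1, 2, 2)
# After padding, the converter's original check would hold:
assert len(padded) == len(sizes)
```

Note that a `-1` in one of the prepended positions should still be rejected, as PyTorch does, rather than silently resolved to the padded size of 1.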

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 8c62fca
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230803+cu121

Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters)