Closed
Description
Bug Description
Returning a list of tensors from a module fails during compilation when an operation is applied to each tensor before it is appended to the returned list.
This does not happen when the input tensor is appended to the list directly, without applying any operation.
RuntimeError: [Error thrown at core/conversion/conversion.cpp:220] List type. Only a single tensor or a TensorList type is supported.
To Reproduce
Run the following:
import torch
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

logging.set_reportable_log_level(logging.Level.Info)
torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 2)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        tensors = []
        for i in range(3):
            y = x + x  # applying an op before appending triggers the error
            tensors.append(y)
        return tensors


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
    model = Model().eval().to(DEVICE)
    out = model(tensor)
    print(out)

    model_trt = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(shape=SHAPE),
        ],
        enabled_precisions={torch.float},
    )
    out_trt = model_trt(tensor)  # run the compiled module, not the original one
    print(out_trt)
This throws the following error:
(trtorch-1.0) ~/av-dbg/experimental/chaoz/trtorch (chaoz/trtorch-experiments) $ python index.py
[tensor([[-1.8493, -0.8507]], device='cuda:0'), tensor([[-1.8493, -0.8507]], device='cuda:0'), tensor([[-1.8493, -0.8507]], device='cuda:0')]
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %y.1 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16
  %y.2 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16
  %y.4 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16
  %tensors.1 : Tensor[] = prim::ListConstruct(%y.1, %y.2, %y.4)
  return (%tensors.1)
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.1. Assuming it is Float32. If not, specify input type explicity
INFO: [Torch-TensorRT] - Skipping partitioning since model is fully supported
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageChange] Init CUDA: CPU +449, GPU +0, now: CPU 3411, GPU 1873 (MiB)
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageSnapshot] Begin constructing builder kernel library: CPU 3411 MiB, GPU 1873 MiB
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageSnapshot] End constructing builder kernel library: CPU 3565 MiB, GPU 1915 MiB
INFO: [Torch-TensorRT] - Settings requested for TensorRT engine:
    Enabled Precisions: Float32
    TF32 Floating Point Computation Enabled: 1
    Truncate Long and Double: 0
    Make Refittable Engine: 0
    Debuggable Engine: 0
    Strict Types: 0
    GPU ID: 0
    Allow GPU Fallback (if running on DLA): 0
    Min Timing Iterations: 2
    Avg Timing Iterations: 1
    Max Workspace Size: 1073741824
    Max Batch Size: Not set
    Device Type: GPU
    GPU ID: 0
    Engine Capability: standard
    Calibrator Created: 0
INFO: [Torch-TensorRT TorchScript Conversion Context] - Converting Block
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Input x.1 (named: input_0): Input(shape: [1, 2], dtype: Float32, format: NCHW\Contiguous\Linear) in engine (conversion.AddInputs)
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %y.1 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16 (ctx.AddLayer)
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %y.2 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16 (ctx.AddLayer)
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %y.4 : Tensor = aten::add(%x.1, %x.1, %2) # /home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py:24:16 (ctx.AddLayer)
Traceback (most recent call last):
  File "/home/chaoz/av-dbg/experimental/chaoz/trtorch/index.py", line 37, in <module>
    model_trt = torchtrt.compile(
  File "/home/chaoz/.anaconda3/envs/trtorch-1.0/lib/python3.9/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/chaoz/.anaconda3/envs/trtorch-1.0/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/conversion.cpp:220] List type. Only a single tensor or a TensorList type is supported.
Expected behavior
Compilation should succeed, and the compiled module should return a list of tensors matching the output of the original model.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.0
- PyTorch Version (e.g. 1.0): 1.10.2
- CPU Architecture: x86-64
- OS (e.g., Linux): Ubuntu 18.04
- How you installed PyTorch (conda, pip, libtorch, source): Conda
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: local
- Python version: 3.9
- CUDA version: 11.6
- GPU models and configuration: Nvidia A10
- Any other relevant information:
Additional context
Note that changing the forward function to the following definition:
    def forward(self, x):
        tensors = []
        for i in range(3):
            # y = x + x
            tensors.append(x)
        return tensors
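compiles without raising an error. Note, however, that no TensorRT engine is built in this case: the lowered graph contains only a prim::ListConstruct of the input, so there are no operations to convert. The output is: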
(trtorch-1.0) ~/av-dbg/experimental/chaoz/trtorch (chaoz/trtorch-experiments) $ python index.py
[tensor([[-0.9247, -0.4253]], device='cuda:0'), tensor([[-0.9247, -0.4253]], device='cuda:0'), tensor([[-0.9247, -0.4253]], device='cuda:0')]
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %tensors.1 : Tensor[] = prim::ListConstruct(%x.1, %x.1, %x.1)
  return (%tensors.1)
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.1. Assuming it is Float32. If not, specify input type explicity
ERROR: [Torch-TensorRT] - Method requested cannot be compiled by Torch-TensorRT.TorchScript.
There is no work to be done since the resulting compiled program will contain an engine that is empty.
This may be because there are no operators that can be added to the TensorRT graph or all operators have a resolved compile time value.
WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
INFO: [Torch-TensorRT] - Partitioned Graph: []
INFO: [Torch-TensorRT] - Segmented Graph: graph(%x.1 : Tensor):
  return ()
WARNING: [Torch-TensorRT] - Didn't generate any TensorRT engines, the compiler did nothing
[tensor([[-0.9247, -0.4253]], device='cuda:0'), tensor([[-0.9247, -0.4253]], device='cuda:0'), tensor([[-0.9247, -0.4253]], device='cuda:0')]
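A possible workaround, offered only as an untested sketch, is to keep the list inside the module, return a single stacked tensor, and split it back into a list after the call. The `StackedModel` wrapper and `unstack` helper below are hypothetical names. The sketch assumes that `aten::stack` can be converted by the installed Torch-TensorRT version. It has only been written to run with plain PyTorch on CPU:

```python
import torch


class Model(torch.nn.Module):
    """The module from the reproduction: returns a Python list of tensors."""

    def forward(self, x):
        tensors = []
        for i in range(3):
            y = x + x
            tensors.append(y)
        return tensors


class StackedModel(torch.nn.Module):
    """Hypothetical workaround wrapper: returns one stacked tensor instead of a list."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stack the list along a new leading dimension so the module has a single tensor output.
        return torch.stack(self.inner(x), dim=0)


def unstack(out: torch.Tensor):
    # Hypothetical helper: recover the list of tensors outside the (compiled) module.
    return list(torch.unbind(out, dim=0))


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 2)
    reference = Model().eval()(x)
    # Scripting mirrors what torchtrt.compile does first with an nn.Module.
    scripted = torch.jit.script(StackedModel(Model()).eval())
    recovered = unstack(scripted(x))
    assert len(recovered) == len(reference)
    assert all(torch.equal(a, b) for a, b in zip(recovered, reference))
    print("stacked output matches the list output")
```

To try this against the reproduction, one would pass `StackedModel(Model()).eval().to(DEVICE)` to `torchtrt.compile` with the same `inputs` and `enabled_precisions`, then call `unstack` on the compiled module's output. Whether Torch-TensorRT 1.0 builds a single engine for this graph has not been verified here.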