The same dynamic-batch error does not show up when using the FX path for SSD300 or ResNet50. Notably, for both models, full compilation in TensorRT is supported in both the TS and FX paths.
Bug Description
When converting the SSD300 object detection network from TorchScript to Torch-TRT, the following error is encountered:
The error arises from an input to aten::reshape which uses both the dynamic batch dimension and the aten::reshape wildcard -1. The dynamic batch dimension and the reshape wildcard are both represented by -1, and that shared value is the cause of the bug.
This bug has also been demonstrated to impact some ResNet50 implementations with dynamic batch sizes.
Bug Source
The source of the error is this line:
TensorRT/core/conversion/converters/impl/shuffle.cpp
Line 90 in c63a5a5
The input new_shape is [-1, 4, -1], implying the desired shape has 2 "implicit" dimensions. This is not the case: the first -1 indicates the batch dimension, while the second is an implicit dimension. Thus, the desired behavior is a reshape from:
$$[-1, 16, 38, 38] \Longrightarrow [-1, 4, 5776].$$
The necessary code modifications would be needed here:
TensorRT/core/conversion/converters/impl/shuffle.cpp
Lines 76 to 83 in c63a5a5
The excerpt embedded at this point in the issue includes the message "Resize is currently not supported when target shape contains more than one dynamic dimension".
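The arithmetic behind the target shape [-1, 4, 5776] can be checked with a short Python sketch. It is an illustration only, not the converter's logic: the helper name resolve_wildcard is hypothetical, and it assumes the dynamic batch dimension sits at index 0 in both shapes and passes through the reshape unchanged.

```python
from math import prod


def resolve_wildcard(input_shape, new_shape, batch_index=0):
    """Resolve the reshape wildcard when the batch dimension is also -1.

    Assumption (not taken from shuffle.cpp): the dynamic batch dimension is
    at `batch_index` in both shapes and is carried through unchanged.
    """
    # Elements per sample, ignoring the dynamic batch dimension.
    per_sample = prod(d for i, d in enumerate(input_shape) if i != batch_index)
    # Product of the explicitly given target dims (no batch, no wildcard).
    known = prod(
        d for i, d in enumerate(new_shape) if i != batch_index and d != -1
    )
    wildcards = [
        i for i, d in enumerate(new_shape) if i != batch_index and d == -1
    ]
    if not wildcards:
        return list(new_shape)
    if len(wildcards) > 1:
        raise ValueError("more than one reshape wildcard besides the batch dim")
    if per_sample % known != 0:
        raise ValueError("target shape does not divide the per-sample size")
    resolved = list(new_shape)
    resolved[wildcards[0]] = per_sample // known
    return resolved


# 16 * 38 * 38 = 23104 elements per sample, and 23104 / 4 = 5776.
print(resolve_wildcard([-1, 16, 38, 38], [-1, 4, -1]))  # [-1, 4, 5776]
```

The batch entry stays -1 in the result because it is only known at run time; only the second -1 is resolved from the static dimensions.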
A potential challenge here is determining which dimension in the reshape input dimensions corresponds to the batch dimension, and which corresponds to the implicit dimension.
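To make the ambiguity concrete, the Python sketch below counts wildcard entries in the target shape in two ways: once with -1 used for both meanings, and once with a separate sentinel for the dynamic dimension, such as the INT32_MIN value suggested in this issue. DYNAMIC_DIM, count_wildcards, and mark_dynamic are hypothetical names chosen for illustration; none of them are part of Torch-TensorRT.

```python
INT32_MIN = -(2 ** 31)

# Hypothetical sentinel for "dynamic dimension"; an assumption for this sketch.
DYNAMIC_DIM = INT32_MIN
# Value that aten::reshape uses for its wildcard dimension.
WILDCARD = -1


def count_wildcards(shape):
    """Count entries that would be read as reshape wildcards."""
    return sum(1 for d in shape if d == WILDCARD)


def mark_dynamic(shape, dynamic_indices):
    """Return a copy of `shape` with the given indices tagged as dynamic."""
    return [DYNAMIC_DIM if i in dynamic_indices else d for i, d in enumerate(shape)]


ambiguous = [-1, 4, -1]               # batch dim and wildcard both look like -1
tagged = mark_dynamic(ambiguous, {0})  # batch dim tagged with the sentinel

print(count_wildcards(ambiguous))  # 2
print(count_wildcards(tagged))     # 1
```

With a single shared value, a check that rejects more than one -1 cannot distinguish the batch dimension from the wildcard. With a distinct sentinel, only the true wildcard is counted.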
Potential Resolution
Consider using a different value than -1 to represent dynamic dimensions, for example INT32_MIN, or some other value which cannot represent any reasonable shape in the original tensor.
To Reproduce
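Steps to reproduce the behavior:
Run torch_tensorrt.compile with SSD300 model as input, using fp32 precision.
Use the dynamic input shape
{"min": [1, 3, 300, 300], "opt": [16, 3, 300, 300], "max": [16, 3, 300, 300]}
and enable truncate_long_and_double with 8 GB workspace.
Expected behavior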
Model should successfully compile to Torch-TRT. Specifically, internal reshape dimensions with dynamic batch should resolve correctly.
Environment
python setup.py develop