
🐛 [Bug] Compilation failure for SSD300 model with dynamic batch #1555


Closed
gs-olive opened this issue Dec 16, 2022 · 3 comments
Labels
bug Something isn't working

Comments

@gs-olive
Collaborator

gs-olive commented Dec 16, 2022

Bug Description

When converting the SSD300 object detection network from TorchScript to Torch-TRT, the following error is encountered:

WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [graphShapeAnalyzer.cpp::analyzeShapes::1872] Error Code 4: Miscellaneous (IShuffleLayer %454 : Tensor = aten::reshape(%446, %451): reshape dimension with more than one -1 wildcard. Reshaping [(# 0 (SHAPE input_0)),16,38,38] to [-1,4,-1].)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [graphShapeAnalyzer.cpp::analyzeShapes::1872] Error Code 4: Miscellaneous (IShuffleLayer %454 : Tensor = aten::reshape(%446, %451): reshape dimension with more than one -1 wildcard. Reshaping [(# 0 (SHAPE input_0)),16,38,38] to [-1,4,-1].)

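The error arises from an aten::reshape whose target shape contains both the dynamic batch dimension and the reshape wildcard -1. Because the dynamic batch dimension is also represented as -1, TensorRT cannot tell the two apart, and this shared use of -1 is the cause of the bug. As the aten::size warning above hints, the batch entry appears to come from evaluating aten::size on the dynamic input, which yields -1.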

This bug has also been demonstrated to impact some ResNet50 implementations with dynamic batch sizes.

Bug Source

The source of the error is this line:

shuffle->setReshapeDimensions(util::toDims(new_shape));

The input new_shape is [-1, 4, -1], which suggests the target shape has two "implicit" (inferred) dimensions. That is not the case: the first -1 denotes the dynamic batch dimension, and only the second is an implicit dimension. The desired behavior is therefore a reshape from:
$$[-1, 16, 38, 38] \Longrightarrow [-1, 4, 5776].$$

The code that would need to be modified is here:

for (size_t i = 0; i < new_shape.size(); i++) {
  if (in_shape[i] == -1)
    nbDynamicDims++;
}
if (nbDynamicDims > 1) {
  TORCHTRT_THROW_ERROR(
      "Resize is currently not supported when target shape contains more than one dynamic dimension");
}

A potential challenge here is determining which -1 in the requested reshape dimensions (new_shape) corresponds to the batch dimension, and which corresponds to the implicit dimension.

Potential Resolution

Consider using a value other than -1 to represent dynamic dimensions, for example INT32_MIN, or some other value that cannot be a valid dimension of the original tensor.

To Reproduce

Steps to reproduce the behavior:

  1. Run torch_tensorrt.compile with the SSD300 model as input, using FP32 precision.
  2. Use dynamic input sizes {"min": [1, 3, 300, 300], "opt": [16, 3, 300, 300], "max": [16, 3, 300, 300]}, enable truncate_long_and_double, and set an 8 GB workspace.

Expected behavior

The model should compile successfully to Torch-TRT. Specifically, internal reshape dimensions involving the dynamic batch should resolve correctly.

Environment

  • Torch-TensorRT Version: 1.4.0.dev0+2ef6c3a5
  • PyTorch Version: 1.14.0.dev20221114+cu116
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.6
@gs-olive gs-olive added the bug Something isn't working label Dec 16, 2022
@gs-olive gs-olive self-assigned this Dec 16, 2022
@gs-olive
Collaborator Author

Updates

The same dynamic-batch error does not occur when using the FX path for SSD300 or ResNet50. Notably, full compilation in TensorRT is supported for both models in both the TS and FX paths.

@github-actions

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days

@gs-olive
Collaborator Author

Fixed by #1851, when using allow_shape_tensors=True as a compilation argument.
