
🐛 [Bug] Error when compiling Punctuation BERT model #1587

Closed
@gs-olive

Bug Description

When compiling the BERT punctuation/capitalization model, the following error is encountered:

RuntimeError: outputs_[i]->uses().empty() INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/ir.cpp":1312, please report a bug to PyTorch. 

To Reproduce

Steps to reproduce the behavior:

  1. Run torch_tensorrt.compile with the BERT punctuation model as input, using fp32 precision.
  2. Choose three fixed-size inputs of shape [1, 256], [1, 256], and [1, 256], and enable truncate_long_and_double with a 12 GB workspace.
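The steps above can be sketched as follows. Only the input shapes, precision, truncate flag, and workspace size come from this report; the variable names and the commented-out call are illustrative assumptions:

```python
# Sketch of the reproduction settings from the steps above.
# Only the shapes, precision, workspace size, and truncate flag come from
# the issue; the names and the commented call are illustrative.

input_shapes = [[1, 256], [1, 256], [1, 256]]  # three fixed-size inputs
compile_settings = {
    "enabled_precisions": {"fp32"},        # fp32 precision
    "truncate_long_and_double": True,      # cast int64/float64 where needed
    "workspace_size": 12 << 30,            # 12 GB, in bytes
}

# The actual call (requires torch and torch_tensorrt; `scripted_model` is
# assumed to be a TorchScript module of the BERT punctuation model):
#
# trt_model = torch_tensorrt.compile(
#     scripted_model,
#     inputs=[torch_tensorrt.Input(shape=s) for s in input_shapes],
#     enabled_precisions={torch.float32},
#     truncate_long_and_double=True,
#     workspace_size=compile_settings["workspace_size"],
# )
```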

Expected behavior

Model should successfully compile with Torch-TRT. Specifically, internal assertion errors of this sort should not occur.

Environment

  • Torch-TensorRT Version: 1.4.0.dev0+f43be5b6
  • PyTorch Version: 1.14.0.dev20221114+cu116
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.6

Additional context + Temporary Solution

Inspecting the error in greater depth shows that it occurs during the lowering phase, in the RemoveDropout lowering pass:

void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {

The model was in evaluation mode, but one of the graph rewrites/replacements performed by this lowering pass appears to trigger the above assertion. Disabling this lowering pass and recompiling the model allows compilation to proceed normally.
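As an alternative to disabling the pass in the Torch-TensorRT source, one possible workaround (an assumption on my part, not the fix described above) is to strip the dropout layers from the eval-mode model before scripting. Since nn.Dropout is already a no-op in eval mode, this should not change model behavior, and it leaves the RemoveDropout pass with nothing to rewrite:

```python
# Hypothetical workaround: replace nn.Dropout with nn.Identity before
# scripting. In eval mode dropout is already an identity op, so this does
# not change model behavior; it only removes the nodes that the
# RemoveDropout lowering pass would otherwise rewrite.
import torch.nn as nn

def strip_dropout(module: nn.Module) -> nn.Module:
    """Recursively replace every nn.Dropout submodule with nn.Identity."""
    for name, child in module.named_children():
        if isinstance(child, nn.Dropout):
            setattr(module, name, nn.Identity())
        else:
            strip_dropout(child)
    return module
```

Usage would be something like `strip_dropout(model.eval())` before calling `torch.jit.script` and `torch_tensorrt.compile`.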

Labels

bug