Skip to content

Commit 7ffa87e

Browse files
TRT engine docs (#1396)
* updated docs * use optimum/gpt2
1 parent 85997c3 commit 7ffa87e

1 file changed

Lines changed: 13 additions & 20 deletions

File tree

  • docs/source/onnxruntime/usage_guides

docs/source/onnxruntime/usage_guides/gpu.mdx

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -291,29 +291,33 @@ We recommend setting these two provider options when using the TensorRT executio
291291
... )
292292
```
293293

294-
TensorRT builds its engine depending on specified input shapes. Unfortunately, in the [current ONNX Runtime implementation](https://github.com/microsoft/onnxruntime/blob/613920d6c5f53a8e5e647c5f1dcdecb0a8beef31/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1677-L1688) (references: [1](https://github.com/microsoft/onnxruntime/issues/13559), [2](https://github.com/microsoft/onnxruntime/issues/13851)), the engine is rebuilt every time an input has a shape smaller than the previously smallest encountered shape, and conversely if the input has a shape larger than the previously largest encountered shape. For example, if a model takes `(batch_size, input_ids)` as inputs, and the model takes successively the inputs:
294+
TensorRT builds its engine depending on specified input shapes. One big issue is that building the engine can be time consuming, especially for large models. Therefore, as a workaround, one recommendation is to build the TensorRT engine with dynamic shapes. This allows to avoid rebuilding the engine for new small and large shapes, which is unwanted once the model is deployed for inference.
295295

296-
1. `input.shape: (4, 5) --> the engine is built (first input)`
297-
2. `input.shape: (4, 10) --> engine rebuilt (10 larger than 5)`
298-
3. `input.shape: (4, 7) --> no rebuild (5 <= 7 <= 10)`
299-
4. `input.shape: (4, 12) --> engine rebuilt (10 <= 12)`
300-
5. `input.shape: (4, 3) --> engine rebuilt (3 <= 5)`
296+
To do so we use the provider's options `trt_profile_min_shapes`, `trt_profile_max_shapes` and `trt_profile_opt_shapes` to specify the minimum, maximum and optimal shapes for the engine. For example, for GPT2, we can use the following shapes:
301297

302-
One big issue is that building the engine can be time consuming, especially for large models. Therefore, as a workaround, one recommendation is to **first build the TensorRT engine with an input of small shape, and then with an input of large shape to have an engine valid for all shapes inbetween**. This allows to avoid rebuilding the engine for new small and large shapes, which is unwanted once the model is deployed for inference.
298+
```python
299+
provider_options = {
300+
"trt_profile_min_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
301+
"trt_profile_opt_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
302+
"trt_profile_max_shapes": "input_ids:1x64,attention_mask:1x64,position_ids:1x64",
303+
}
304+
```
303305

304306
Passing the engine cache path in the provider options, the engine can therefore be built once for all and used fully for inference thereafter.
305307

306308
For example, for text generation, the engine can be built with:
307309

308310
```python
309311
>>> import os
310-
>>> from transformers import AutoTokenizer
311312
>>> from optimum.onnxruntime import ORTModelForCausalLM
312313

313314
>>> os.makedirs("tmp/trt_cache_gpt2_example", exist_ok=True)
314315
>>> provider_options = {
315316
... "trt_engine_cache_enable": True,
316-
... "trt_engine_cache_path": "tmp/trt_cache_gpt2_example"
317+
... "trt_engine_cache_path": "tmp/trt_cache_gpt2_example",
318+
... "trt_profile_min_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
319+
... "trt_profile_opt_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
320+
... "trt_profile_max_shapes": "input_ids:1x64,attention_mask:1x64,position_ids:1x64",
317321
... }
318322

319323
>>> ort_model = ORTModelForCausalLM.from_pretrained(
@@ -322,17 +326,6 @@ For example, for text generation, the engine can be built with:
322326
... provider="TensorrtExecutionProvider",
323327
... provider_options=provider_options,
324328
... )
325-
>>> tokenizer = AutoTokenizer.from_pretrained("optimum/gpt2")
326-
327-
>>> print("Building engine for a short sequence...") # doctest: +IGNORE_RESULT
328-
>>> text = ["short"]
329-
>>> encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
330-
>>> output = ort_model(**encoded_input)
331-
332-
>>> print("Building engine for a long sequence...") # doctest: +IGNORE_RESULT
333-
>>> text = [" a very long input just for demo purpose, this is very long" * 10]
334-
>>> encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
335-
>>> output = ort_model(**encoded_input)
336329
```
337330

338331
The engine is stored as:

0 commit comments

Comments
 (0)