Commit 1efe4b1

Merge pull request #1254 from pytorch/perf_changes
feat(//tools/perf): Refactor perf_run.py, add fx2trt backend support, usage via CLI arguments
2 parents 7142c82 + 77543a0

File tree: 7 files changed (+633, -131 lines)

tools/perf/README.md

Lines changed: 70 additions & 22 deletions
@@ -4,7 +4,9 @@ This is a comprehensive Python benchmark suite to run perf runs using different

1. Torch
2. Torch-TensorRT
-3. TensorRT
+3. FX-TRT
+4. TensorRT
+

Note: For ONNX models, users can convert the ONNX model to a TensorRT serialized engine and then use this package.

@@ -25,21 +27,35 @@ Benchmark scripts depends on following Python packages in addition to requiremen
│ └── vgg16.yml
├── models
├── perf_run.py
+├── hub.py
+├── custom_models.py
+├── requirements.txt
+├── benchmark.sh
└── README.md
```

-Please save your configuration files at config directory. Similarly, place your model files at models path.
+* `config` - Directory which contains sample yaml configuration files for the VGG network.
+* `models` - Model directory
+* `perf_run.py` - Performance benchmarking script which supports torch, torch_tensorrt, fx2trt and tensorrt backends
+* `hub.py` - Script to download torchscript models for VGG16, Resnet50, EfficientNet-B0, VIT and HF-BERT
+* `custom_models.py` - Script which includes custom models other than torchvision and timm (e.g. HF BERT)
+* `utils.py` - Utility functions script
+* `benchmark.sh` - Used for internal performance testing of VGG16, Resnet50, EfficientNet-B0, VIT and HF-BERT

## Usage

+There are two ways you can run a performance benchmark.
+
+### Using YAML config files
+
To run the benchmark for a given configuration file:

-```
+```python
python perf_run.py --config=config/vgg16.yml
```

-## Configuration
-
There are two sample configuration files added.

* vgg16.yml demonstrates a configuration with all the supported backends (Torch, Torch-TensorRT, TensorRT)
@@ -48,23 +64,17 @@ There are two sample configuration files added.

### Supported fields

-| Name | Supported Values | Description |
-| --- | --- | --- |
-| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. |
-| input | - | Input binding names. Expected to list shapes of each input bindings |
-| model | - | Configure the model filename and name |
-| filename | - | Model file name to load from disk. |
-| name | - | Model name |
-| runtime | - | Runtime configurations |
-| device | 0 | Target device ID to run inference. Range depends on available GPUs |
-| precision | fp32, fp16 or half, int8 | Target precision to run inference. int8 cannot be used with 'all' backend |
-| calibration_cache | - | Calibration cache file expected for torch_tensorrt runtime in int8 precision |
-
-Note:
-1. Please note that torch runtime perf is not supported for int8 yet.
-2. Torchscript module filename should end with .jit.pt otherwise it will be treated as a TensorRT engine.
-
+| Name              | Supported Values                     | Description                                                                       |
+| ----------------- | ------------------------------------ | --------------------------------------------------------------------------------- |
+| backend           | all, torch, torch_tensorrt, tensorrt | Supported backends for inference.                                                  |
+| input             | -                                    | Input binding names. Expected to list the shape of each input binding              |
+| model             | -                                    | Configure the model filename and name                                              |
+| filename          | -                                    | Model file name to load from disk.                                                 |
+| name              | -                                    | Model name                                                                         |
+| runtime           | -                                    | Runtime configurations                                                             |
+| device            | 0                                    | Target device ID to run inference. Range depends on available GPUs                 |
+| precision         | fp32, fp16 or half, int8             | Target precision to run inference. int8 cannot be used with the 'all' backend      |
+| calibration_cache | -                                    | Calibration cache file expected for the torch_tensorrt runtime in int8 precision   |
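
To make these fields concrete, here is a minimal configuration sketch assembled from the table above and the `vgg16.yml` sample in this commit; the `input0` binding name and the exact nesting are illustrative assumptions, not a verbatim copy of the shipped file:

```
backend:
  - torch
  - torch_tensorrt
input:
  input0:        # assumed binding name; list the shape of each input binding
    - 3
    - 224
    - 224
  num_inputs: 1
batch_size: 1
model:
  filename: models/vgg16_scripted.jit.pt
  name: vgg16
runtime:
  device: 0
  precision:
    - fp32
    - fp16
```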

Additional sample use case:

@@ -88,3 +98,41 @@ runtime:
  - fp32
  - fp16
```
+
+Note:
+
+1. Measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for the `torch_tensorrt` backend.
+2. A TensorRT engine filename should end with `.plan`; otherwise the file will be treated as a Torchscript module.
+
+### Using CompileSpec options via CLI
+
+Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module:
+
+* `--backends` : Comma-separated string of backends. Eg: torch,torch_tensorrt,tensorrt or fx2trt
+* `--model` : Name of the model file (can be a torchscript module or a TensorRT engine, ending in the `.plan` extension). If the backend is `fx2trt`, the input should be a PyTorch module (instead of a torchscript module) and the options for the model are (`vgg16` | `resnet50` | `efficientnet_b0`)
+* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
+* `--batch_size` : Batch size
+* `--precision` : Comma-separated list of precisions to build the TensorRT engine. Eg: fp32,fp16
+* `--device` : Device ID
+* `--truncate` : Truncate long and double weights in the network in Torch-TensorRT
+* `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine.
+* `--report` : Path of the output file where the performance summary is written.
+
+Eg:
+
+```
+python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
+                   --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
+                   --batch_size 1 \
+                   --backends torch,torch_tensorrt,tensorrt \
+                   --report "vgg_perf_bs1.txt"
+```
+
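For the `fx2trt` backend, `--model` takes one of the supported model names instead of a file on disk. A hypothetical invocation, following the option descriptions above, might look like:

```
python perf_run.py --model vgg16 \
                   --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
                   --batch_size 1 \
                   --backends fx2trt \
                   --report "vgg_fx2trt_bs1.txt"
```
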
+### Example models
+
+This tool benchmarks any PyTorch model or torchscript module. As an example, we provide VGG16, Resnet50, EfficientNet-B0, VIT and HF-BERT models in `hub.py` that we test internally for performance.
+The torchscript modules for these models can be generated by running
+```
+python hub.py
+```
+You can refer to `benchmark.sh` for how we run/benchmark these models.
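
Going by the filename scheme in `hub.py` and `benchmark.sh`, a successful `python hub.py` run should leave roughly these files in `models/` (illustrative listing, not a captured output):

```
models/vgg16_scripted.jit.pt
models/resnet50_scripted.jit.pt
models/efficientnet_b0_scripted.jit.pt
models/vit_scripted.jit.pt
models/bert_base_uncased_traced.jit.pt
```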

tools/perf/benchmark.sh

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
```
#!/bin/bash

MODELS_DIR="models"

# Download the Torchscript models
python hub.py

batch_sizes=(1 2 4 8 16 32 64 128 256)

# Benchmark VGG16 model
echo "Benchmarking VGG16 model"
for bs in ${batch_sizes[@]}
do
  python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                     --batch_size ${bs} \
                     --backends torch,torch_tensorrt,tensorrt \
                     --report "vgg_perf_bs${bs}.txt"
done

# Benchmark Resnet50 model
echo "Benchmarking Resnet50 model"
for bs in ${batch_sizes[@]}
do
  python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                     --batch_size ${bs} \
                     --backends torch,torch_tensorrt,tensorrt \
                     --report "rn50_perf_bs${bs}.txt"
done

# Benchmark VIT model
echo "Benchmarking VIT model"
for bs in ${batch_sizes[@]}
do
  python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \
                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                     --batch_size ${bs} \
                     --backends torch,torch_tensorrt,tensorrt \
                     --report "vit_perf_bs${bs}.txt"
done

# Benchmark EfficientNet-B0 model
echo "Benchmarking EfficientNet-B0 model"
for bs in ${batch_sizes[@]}
do
  python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                     --batch_size ${bs} \
                     --backends torch,torch_tensorrt,tensorrt \
                     --report "eff_b0_perf_bs${bs}.txt"
done

# Benchmark BERT model
echo "Benchmarking Huggingface BERT base model"
for bs in ${batch_sizes[@]}
do
  python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \
                     --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
                     --batch_size ${bs} \
                     --backends torch,torch_tensorrt \
                     --truncate \
                     --report "bert_base_perf_bs${bs}.txt"
done
```
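
A typical way to launch this sweep, assuming a CUDA-capable GPU and the packages from `requirements.txt` are installed, is:

```
cd tools/perf
bash benchmark.sh
# summaries land in per-model, per-batch-size files such as vgg_perf_bs1.txt ... bert_base_perf_bs256.txt
```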

tools/perf/config/vgg16.yml

Lines changed: 2 additions & 1 deletion
@@ -8,8 +8,9 @@ input:
    - 224
    - 224
  num_inputs: 1
+batch_size: 1
model:
-  filename: models/vgg16_traced.jit.pt
+  filename: models/vgg16_scripted.jit.pt
  name: vgg16
runtime:
  device: 0

tools/perf/custom_models.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
```python
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
import torch.nn.functional as F


def BertModule():
    model_name = "bert-base-uncased"
    enc = BertTokenizer.from_pretrained(model_name)
    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
    tokenized_text = enc.tokenize(text)
    masked_index = 8
    tokenized_text[masked_index] = "[MASK]"
    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    config = BertConfig(
        vocab_size_or_config_json_file=32000,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        torchscript=True,
    )
    model = BertModel(config)
    model.eval()
    model = BertModel.from_pretrained(model_name, torchscript=True)
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    return traced_model
```
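
For context, `hub.py` consumes this helper roughly as sketched below. Since `BertModule()` already returns a traced module, it can be saved directly for `perf_run.py` to load; the output path follows the naming used by `benchmark.sh`:

```python
import torch

import custom_models as cm

# BertModule() returns a torch.jit.trace()'d BertModel, so no re-tracing
# is needed before serializing it for the benchmark scripts.
traced_model = cm.BertModule()
torch.jit.save(traced_model, "models/bert_base_uncased_traced.jit.pt")
```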

tools/perf/hub.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm
from transformers import BertModel, BertTokenizer, BertConfig
import os
import json
import custom_models as cm

torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

torch_version = torch.__version__

# Detect case of no GPU before deserialization of models on GPU
if not torch.cuda.is_available():
    raise Exception(
        "No GPU found. Please check if installed torch version is compatible with CUDA version"
    )

# Downloads all model files again if manifest file is not present
MANIFEST_FILE = "model_manifest.json"

BENCHMARK_MODELS = {
    "vgg16": {"model": models.vgg16(weights=None), "path": "script"},
    "resnet50": {"model": models.resnet50(weights=None), "path": "script"},
    "efficientnet_b0": {
        "model": timm.create_model("efficientnet_b0", pretrained=True),
        "path": "script",
    },
    "vit": {
        "model": timm.create_model("vit_base_patch16_224", pretrained=True),
        "path": "script",
    },
    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
}


def get(n, m, manifest):
    print("Downloading {}".format(n))
    traced_filename = "models/" + n + "_traced.jit.pt"
    script_filename = "models/" + n + "_scripted.jit.pt"
    x = torch.ones((1, 3, 300, 300)).cuda()
    # BERT is already traced by custom_models.BertModule(); save it as-is.
    # Note: the key must match BENCHMARK_MODELS ("bert_base_uncased", underscores).
    if n == "bert_base_uncased":
        traced_model = m["model"]
        torch.jit.save(traced_model, traced_filename)
        manifest.update({n: [traced_filename]})
    else:
        m["model"] = m["model"].eval().cuda()
        if m["path"] == "both" or m["path"] == "trace":
            trace_model = torch.jit.trace(m["model"], [x])
            torch.jit.save(trace_model, traced_filename)
            manifest.update({n: [traced_filename]})
        if m["path"] == "both" or m["path"] == "script":
            script_model = torch.jit.script(m["model"])
            torch.jit.save(script_model, script_filename)
            if n in manifest.keys():
                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
                files.append(script_filename)
                manifest.update({n: files})
            else:
                manifest.update({n: [script_filename]})
    return manifest


def download_models(version_matches, manifest):
    # Download all models if torch version is different than model version
    if not version_matches:
        for n, m in BENCHMARK_MODELS.items():
            manifest = get(n, m, manifest)
    else:
        for n, m in BENCHMARK_MODELS.items():
            scripted_filename = "models/" + n + "_scripted.jit.pt"
            traced_filename = "models/" + n + "_traced.jit.pt"
            # Check if model file exists on disk
            if (
                (
                    m["path"] == "both"
                    and os.path.exists(scripted_filename)
                    and os.path.exists(traced_filename)
                )
                or (m["path"] == "script" and os.path.exists(scripted_filename))
                or (m["path"] == "trace" and os.path.exists(traced_filename))
            ):
                print("Skipping {} ".format(n))
                continue
            manifest = get(n, m, manifest)


def main():
    manifest = None
    version_matches = False
    manifest_exists = False

    # Check if Manifest file exists or is empty
    if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
        manifest = {"version": torch_version}

        # Creating an empty manifest file for overwriting post setup
        os.system("touch {}".format(MANIFEST_FILE))
    else:
        manifest_exists = True

        # Load manifest if already exists
        with open(MANIFEST_FILE, "r") as f:
            manifest = json.load(f)
            if manifest["version"] == torch_version:
                version_matches = True
            else:
                print(
                    "Torch version: {} mismatches "
                    "with manifest's version: {}. Re-downloading "
                    "all models".format(torch_version, manifest["version"])
                )

                # Overwrite the manifest version as current torch version
                manifest["version"] = torch_version

    download_models(version_matches, manifest)

    # Write updated manifest file to disk
    with open(MANIFEST_FILE, "r+") as f:
        data = f.read()
        f.seek(0)
        record = json.dumps(manifest)
        f.write(record)
        f.truncate()


main()
```
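
After a run, `model_manifest.json` maps the torch version plus each model name to the files generated for it; a sketch of its shape is below (the version string and file lists depend on your environment):

```
{
  "version": "<torch.__version__>",
  "vgg16": ["models/vgg16_scripted.jit.pt"],
  "resnet50": ["models/resnet50_scripted.jit.pt"],
  "efficientnet_b0": ["models/efficientnet_b0_scripted.jit.pt"],
  "vit": ["models/vit_scripted.jit.pt"],
  "bert_base_uncased": ["models/bert_base_uncased_traced.jit.pt"]
}
```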
