forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: evaluate.py
More file actions
127 lines (112 loc) · 4.18 KB
/
evaluate.py
File metadata and controls
127 lines (112 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import csv
import timm
from evaluation import evaluate
from modelopt.torch._deploy._runtime import RuntimeRegistry
from modelopt.torch._deploy.device_model import DeviceModel
from modelopt.torch._deploy.utils import OnnxBytes
def main():
    """Evaluate an image-classification ONNX model with TensorRT.

    Compiles the given ONNX model to a TensorRT engine through modelopt's
    deploy runtime, measures top-1 / top-5 accuracy on ImageNet and the
    engine's inference latency, and optionally writes the results to a
    two-column CSV file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--onnx_path",
        type=str,
        required=True,
        help="""Path to the image classification ONNX model with input shape of
        [batch_size,3,224,224] and output shape of [1,1000]""",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="use for timm.create_model to load data config",
    )
    parser.add_argument(
        "--engine_path",
        type=str,
        default=None,
        help="Path to the TensorRT engine",
    )
    parser.add_argument(
        "--timing_cache_path",
        type=str,
        default=None,
        help="Path to the TensorRT timing cache",
    )
    parser.add_argument(
        "--imagenet_path",
        type=str,
        default="ILSVRC/imagenet-1k",
        help="HF dataset card or local path to the ImageNet dataset",
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for evaluation")
    parser.add_argument(
        "--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
    )
    parser.add_argument(
        "--engine_precision",
        type=str,
        default="stronglyTyped",
        choices=["best", "fp16", "stronglyTyped"],
        help="Precision mode for the TensorRT engine. \
            stronglyTyped is recommended, all other modes have been deprecated in TensorRT",
    )
    parser.add_argument(
        "--results_path", type=str, default=None, help="Save the results to the specified path"
    )
    args = parser.parse_args()

    deployment = {
        "runtime": "TRT",
        "precision": args.engine_precision,
    }

    # Serialize the ONNX model so the runtime client can consume it.
    onnx_bytes = OnnxBytes(args.onnx_path).to_bytes()

    # Get the runtime client for the requested deployment configuration.
    client = RuntimeRegistry.get(deployment)

    # Compile the ONNX model to a TRT engine and wrap it as a device model.
    compilation_args = {
        "engine_path": args.engine_path,
        "timing_cache_path": args.timing_cache_path,
    }
    compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
    device_model = DeviceModel(client, compiled_model, metadata={})

    # The timm model is created only to resolve the preprocessing transforms
    # (pretrained=False: its weights are never used for inference here).
    model = timm.create_model(args.model_name, pretrained=False, num_classes=1000)
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    # NOTE: the dead pre-initialization of top1/top5 to 0.0 was removed —
    # evaluate() unconditionally assigns both values.
    top1_accuracy, top5_accuracy = evaluate(
        device_model,
        transforms,
        batch_size=args.batch_size,
        num_examples=args.eval_data_size,
        dataset_path=args.imagenet_path,
    )
    print(f"The top1 accuracy of the model is {top1_accuracy}%")
    print(f"The top5 accuracy of the model is {top5_accuracy}%")

    latency = device_model.get_latency()
    print(f"Inference latency of the model is {latency} ms")

    # Optionally persist the metrics as a header row plus one row per metric.
    if args.results_path:
        results: list[list[str | float]] = [
            ["Metric", "Value"],
            ["Top 1", top1_accuracy],
            ["Top 5", top5_accuracy],
            ["Latency", latency],
        ]
        with open(args.results_path, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerows(results)


if __name__ == "__main__":
    main()