From 35b51cc52de97f912f98c69a54089750957d0dea Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 1 Nov 2023 11:46:04 -0700
Subject: [PATCH] feat/fix: Add new models, fix perf scripts

- Add new key models to benchmarking scripts
- Add fixes and improvements to existing benchmarking code
---
 .github/workflows/docker_builder.yml |  1 +
 tools/perf/benchmark.sh              | 60 +++++++++++++++---
 tools/perf/custom_models.py          | 35 ++++++-----
 tools/perf/hub.py                    | 92 ++++++++++------------------
 tools/perf/perf_run.py               | 62 +++++++++++++------
 tools/perf/requirements.txt          |  2 +
 tools/perf/utils.py                  | 33 ++++++++--
 7 files changed, 178 insertions(+), 107 deletions(-)

diff --git a/.github/workflows/docker_builder.yml b/.github/workflows/docker_builder.yml
index 817bc87c82..99b6efe53e 100644
--- a/.github/workflows/docker_builder.yml
+++ b/.github/workflows/docker_builder.yml
@@ -6,6 +6,7 @@ on:
     branches:
       - main
       - nightly
+      - release/2.1
 
 # If pushes to main are made in rapid succession,
 # cancel existing docker builds and use newer commits
diff --git a/tools/perf/benchmark.sh b/tools/perf/benchmark.sh
index 870f5352cc..78737e685f 100644
--- a/tools/perf/benchmark.sh
+++ b/tools/perf/benchmark.sh
@@ -6,8 +6,10 @@ MODELS_DIR="models"
 python hub.py
 
 batch_sizes=(1 2 4 8 16 32 64 128 256)
+large_model_batch_sizes=(1 2 4 8 16 32 64)
 
-#Benchmark VGG16 model
+
+# Benchmark VGG16 model
 echo "Benchmarking VGG16 model"
 for bs in ${batch_sizes[@]}
 do
@@ -15,11 +17,25 @@ do
     --model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
+    --truncate \
+    --backends torch,ts_trt,dynamo,torch_compile,inductor \
+    --report "vgg16_perf_bs${bs}.txt"
+done
+
+# Benchmark AlexNet model
+echo "Benchmarking AlexNet model"
+for bs in ${batch_sizes[@]}
+do
+  python perf_run.py --model ${MODELS_DIR}/alexnet_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/alexnet_pytorch.pt \
+    --precision fp32,fp16 --inputs="(${bs}, 3, 227, 227)" \
+    --batch_size ${bs} \
+    --truncate \
     --backends torch,ts_trt,dynamo,torch_compile,inductor \
-    --report "vgg_perf_bs${bs}.txt"
+    --report "alexnet_perf_bs${bs}.txt"
 done
 
-# # Benchmark Resnet50 model
+# Benchmark Resnet50 model
 echo "Benchmarking Resnet50 model"
 for bs in ${batch_sizes[@]}
 do
@@ -27,22 +43,38 @@ do
     --model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
+    --truncate \
     --backends torch,ts_trt,dynamo,torch_compile,inductor \
-    --report "rn50_perf_bs${bs}.txt"
+    --report "resnet50_perf_bs${bs}.txt"
 done
 
-# # Benchmark VIT model
+# Benchmark VIT model
 echo "Benchmarking VIT model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/vit_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
+    --truncate \
     --backends torch,ts_trt,dynamo,torch_compile,inductor \
     --report "vit_perf_bs${bs}.txt"
 done
 
-# # Benchmark EfficientNet-B0 model
+# Benchmark VIT Large model
+echo "Benchmarking VIT Large model"
+for bs in ${large_model_batch_sizes[@]}
+do
+  python perf_run.py --model ${MODELS_DIR}/vit_large_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/vit_large_pytorch.pt \
+    --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+    --truncate \
+    --batch_size ${bs} \
+    --backends torch,ts_trt,dynamo,torch_compile,inductor \
+    --report "vit_large_perf_bs${bs}.txt"
+done
+
+# Benchmark EfficientNet-B0 model
 echo "Benchmarking EfficientNet-B0 model"
model" for bs in ${batch_sizes[@]} do @@ -50,8 +81,21 @@ do --model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \ --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ --batch_size ${bs} \ + --truncate \ --backends torch,ts_trt,dynamo,torch_compile,inductor \ - --report "eff_b0_perf_bs${bs}.txt" + --report "efficientnet_b0_perf_bs${bs}.txt" +done + +# Benchmark Stable Diffusion UNet model +echo "Benchmarking SD UNet model" +for bs in ${large_model_batch_sizes[@]} +do + python perf_run.py --model_torch ${MODELS_DIR}/sd_unet_pytorch.pt \ + --precision fp32,fp16 --inputs="(${bs}, 4, 128, 128)@fp16;(${bs})@fp16;(${bs}, 1, 768)@fp16" \ + --batch_size ${bs} \ + --backends torch,dynamo,torch_compile,inductor \ + --truncate \ + --report "sd_unet_perf_bs${bs}.txt" done # Benchmark BERT model @@ -60,7 +104,7 @@ for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \ --model_torch "bert_base_uncased" \ - --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ + --precision fp32,fp16 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ --batch_size ${bs} \ --backends torch,ts_trt,dynamo,torch_compile,inductor \ --truncate \ diff --git a/tools/perf/custom_models.py b/tools/perf/custom_models.py index a8b8a5dae0..0f85957e1e 100644 --- a/tools/perf/custom_models.py +++ b/tools/perf/custom_models.py @@ -1,10 +1,18 @@ import torch -import torch.nn as nn -from transformers import BertModel, BertTokenizer, BertConfig -import torch.nn.functional as F def BertModule(): + from transformers import BertModel + + model_name = "bert-base-uncased" + model = BertModel.from_pretrained(model_name, torchscript=True) + model.eval() + return model + + +def BertInputs(): + from transformers import BertTokenizer + model_name = "bert-base-uncased" enc = BertTokenizer.from_pretrained(model_name) text = "[CLS] Who was Jim Henson ? 
@@ -15,16 +23,13 @@ def BertModule():
     segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
     tokens_tensor = torch.tensor([indexed_tokens])
     segments_tensors = torch.tensor([segments_ids])
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
+    return [tokens_tensor, segments_tensors]
+
+
+def StableDiffusionUnet():
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
     )
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained(model_name, torchscript=True)
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    return traced_model
+    return pipe.unet
diff --git a/tools/perf/hub.py b/tools/perf/hub.py
index d6dbd95376..4a14b59a4b 100644
--- a/tools/perf/hub.py
+++ b/tools/perf/hub.py
@@ -1,13 +1,7 @@
 import json
 import os
 
-import custom_models as cm
-import timm
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.models as models
-from transformers import BertConfig, BertModel, BertTokenizer
 
 torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 
@@ -26,25 +20,7 @@
 VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")
 
 # Key models selected for benchmarking with their respective paths
-BENCHMARK_MODELS = {
-    "vgg16": {
-        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
-        "path": ["script", "pytorch"],
-    },
-    "resnet50": {
-        "model": models.resnet50(weights=None),
-        "path": ["script", "pytorch"],
-    },
-    "efficientnet_b0": {
-        "model": timm.create_model("efficientnet_b0", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "vit": {
-        "model": timm.create_model("vit_base_patch16_224", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
-}
+from utils import BENCHMARK_MODELS
 
 
 def get(n, m, manifest):
@@ -51,40 +27,36 @@
     traced_filename = "models/" + n + "_traced.jit.pt"
     script_filename = "models/" + n + "_scripted.jit.pt"
     pytorch_filename = "models/" + n + "_pytorch.pt"
-    x = torch.ones((1, 3, 300, 300)).cuda()
-    if n == "bert_base_uncased":
-        traced_model = m["model"]
-        torch.jit.save(traced_model, traced_filename)
+
+    m["model"] = m["model"].eval().cuda()
+
+    # Get all desired model save specifications as list
+    paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
+
+    # Depending on specified model save specifications, save desired model formats
+    if any(path in ("all", "torchscript", "trace") for path in paths):
+        # (TorchScript) Traced model
+        trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]])
+        torch.jit.save(trace_model, traced_filename)
         manifest.update({n: [traced_filename]})
-    else:
-        m["model"] = m["model"].eval().cuda()
-
-        # Get all desired model save specifications as list
-        paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
-
-        # Depending on specified model save specifications, save desired model formats
-        if any(path in ("all", "torchscript", "trace") for path in paths):
-            # (TorchScript) Traced model
-            trace_model = torch.jit.trace(m["model"], [x])
-            torch.jit.save(trace_model, traced_filename)
-            manifest.update({n: [traced_filename]})
-        if any(path in ("all", "torchscript", "script") for path in paths):
-            # (TorchScript) Scripted model
-            script_model = torch.jit.script(m["model"])
-            torch.jit.save(script_model, script_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
-        if any(path in ("all", "pytorch") for path in paths):
-            # (PyTorch Module) model
-            torch.save(m["model"], pytorch_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
+    if any(path in ("all", "torchscript", "script") for path in paths):
+        # (TorchScript) Scripted model
+        script_model = torch.jit.script(m["model"])
+        torch.jit.save(script_model, script_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(script_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [script_filename]})
+    if any(path in ("all", "pytorch") for path in paths):
+        # (PyTorch Module) model
+        torch.save(m["model"], pytorch_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(pytorch_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [pytorch_filename]})
+
     return manifest
diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py
index c96c9d3f8d..d7ee46279a 100644
--- a/tools/perf/perf_run.py
+++ b/tools/perf/perf_run.py
@@ -2,6 +2,7 @@
 # Config parsers and report generations
 import argparse
+import logging
 import os
 import time
 import timeit
@@ -14,7 +15,6 @@
 # Importing supported Backends
 import torch
 import torch.backends.cudnn as cudnn
-import torch_tensorrt as torchtrt
 from utils import (
     BENCHMARK_MODELS,
     parse_backends,
@@ -23,11 +23,26 @@
     precision_to_dtype,
 )
 
+import torch_tensorrt as torchtrt
+
 WARMUP_ITER = 10
 results = []
 
 
+def run_with_try_except(func):
+    def wrapper_func(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except KeyboardInterrupt:
+            raise
+        except Exception:
+            logging.warning(f"Running {func} failed", exc_info=True)
+
+    return wrapper_func
+
+
 # Runs inference using Torch backend
+@run_with_try_except
 def run_torch(model, input_tensors, params, precision, batch_size):
     print("Running Torch for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
@@ -53,6 +68,7 @@
 
 
 # Runs inference using Torch-TensorRT backend
+@run_with_try_except
 def run_ts_trt(model, input_tensors, params, precision, batch_size):
     print(
         "Running Torch-TensorRT for precision: ",
@@ -71,9 +87,9 @@
         compile_settings.update({"calib": params.get("calibration_cache")})
 
     start_compile = time.time_ns()
-    model = torchtrt.compile(model, **compile_settings)
+    model = torchtrt.compile(model, ir="ts", **compile_settings)
     end_compile = time.time_ns()
-    compile_time_ms = (end_compile - start_compile) / 1e6
+    compile_time_s = (end_compile - start_compile) / 1e9
 
     iters = params.get("iterations", 20)
     # Warm up
@@ -94,10 +110,11 @@
         timings.append(meas_time)
 
     recordStats(
-        "Torch-TensorRT [Torchscript]", timings, precision, batch_size, compile_time_ms
+        "Torch-TensorRT [Torchscript]", timings, precision, batch_size, compile_time_s
     )
 
+@run_with_try_except
 def run_dynamo(model, input_tensors, params, precision, batch_size):
     """
     Compile the given model using Torch-TensorRT dynamo frontend and record performance stats
     """
@@ -119,7 +136,7 @@
         truncate_long_and_double=params.get("truncate", False),
     )
     end_compile = time.time_ns()
-    compile_time_ms = (end_compile - start_compile) / 1e6
+    compile_time_s = (end_compile - start_compile) / 1e9
     iters = params.get("iterations", 20)
     # Warm up
     with torch.no_grad():
@@ -139,14 +156,17 @@
         timings.append(meas_time)
 
     recordStats(
-        "Torch-TensorRT [Dynamo]", timings, precision, batch_size, compile_time_ms
+        "Torch-TensorRT [Dynamo]", timings, precision, batch_size, compile_time_s
     )
 
 
+@run_with_try_except
 def run_torch_compile(model, input_tensors, params, precision, batch_size):
     """
     Compile the given model using Torch-TensorRT torch.compile frontend and record performance stats
     """
+    torch._dynamo.reset()
+
     print(
         "Running Torch-TensorRT [torch_compile] for precision: ",
@@ -165,7 +185,7 @@
     )
     model(*input_tensors)
     end_compile = time.time_ns()
-    compile_time_ms = (end_compile - start_compile) / 1e6
+    compile_time_s = (end_compile - start_compile) / 1e9
     iters = params.get("iterations", 20)
     # Warm up
     with torch.no_grad():
@@ -191,16 +211,19 @@
         timings,
         precision,
         batch_size,
-        compile_time_ms,
+        compile_time_s,
     )
 
 
+@run_with_try_except
 def run_inductor(model, input_tensors, params, precision, batch_size):
     """
     Compile the given model using torch inductor and record performance stats
     """
+    torch._dynamo.reset()
+
     print(
-        "Running Torch-TensorRT [inductor] for precision: ",
+        "Running Torch [inductor] for precision: ",
         precision,
         " batch_size : ",
         batch_size,
@@ -210,7 +233,7 @@
     )
     start_compile = time.time_ns()
     model = torch.compile(model, backend="inductor", dynamic=False, mode="max-autotune")
     model(*input_tensors)
     end_compile = time.time_ns()
-    compile_time_ms = (end_compile - start_compile) / 1e6
+    compile_time_s = (end_compile - start_compile) / 1e9
     iters = params.get("iterations", 20)
     # Warm up
     with torch.no_grad():
@@ -232,11 +255,11 @@
     torch._dynamo.reset()
 
     recordStats(
-        "Torch-TensorRT [inductor]",
+        "Torch [inductor]",
         timings,
         precision,
         batch_size,
-        compile_time_ms,
+        compile_time_s,
     )
@@ -264,6 +287,7 @@
         return TypeError("%s is not supported by torch" % device)
 
 
+@run_with_try_except
 def run_tensorrt(
     model,
     input_tensors,
@@ -356,7 +380,7 @@
             print("int8 precision expects calibration cache file for inference")
             return False
 
-    if (model is None) and (backend != "fx2trt"):
+    if (model is None) and (backend in ("tensorrt", "ts_trt", "all")):
         warnings.warn(
             f"Requested backend {backend} without specifying a TorchScript Model, "
             + "skipping this backend"
@@ -423,7 +447,7 @@
 
 
 # Generate report
-def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None):
+def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None):
     times = np.array(timings)
     steps = len(times)
     speeds = batch_size / times
@@ -442,7 +466,7 @@
         "Mean(FPS)": speed_mean,
"Median-Latency(ms)": time_med * 1000, "Mean-Latency(ms)": time_mean * 1000, - "Compile Time(ms)": compile_time_ms, + "Compile Time(s)": compile_time_s, } results.append(stats) @@ -488,7 +512,7 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None) arg_parser.add_argument( "--truncate", action="store_true", - help="Truncate long and double weights in the network in Torch-TensorRT", + help="Truncate long and double weights in the network in Torch-TensorRT", ) arg_parser.add_argument( "--is_trt_engine", @@ -563,8 +587,10 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None) if not is_trt_engine and (precision == "fp16" or precision == "half"): # If model is TensorRT serialized engine then model.half will report failure - model = model.half() - model_torch = model_torch.half() + if model is not None: + model = model.half() + if model_torch is not None: + model_torch = model_torch.half() status = run( model, diff --git a/tools/perf/requirements.txt b/tools/perf/requirements.txt index f9f8813feb..4cf81b69f2 100644 --- a/tools/perf/requirements.txt +++ b/tools/perf/requirements.txt @@ -2,4 +2,6 @@ timeit numpy argparse yaml +transformers==4.33.2 +diffusers==0.21.4 pandas==2.0.1 diff --git a/tools/perf/utils.py b/tools/perf/utils.py index 96a13ffbc2..6e071da0f2 100644 --- a/tools/perf/utils.py +++ b/tools/perf/utils.py @@ -1,14 +1,19 @@ -import torch -import torch_tensorrt import custom_models as cm -import torchvision.models as models import timm +import torch +import torchvision.models as models + +import torch_tensorrt BENCHMARK_MODELS = { "vgg16": { "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), "path": ["script", "pytorch"], }, + "alexnet": { + "model": models.alexnet(weights=models.AlexNet_Weights.DEFAULT), + "path": ["script", "pytorch"], + }, "resnet50": { "model": models.resnet50(weights=None), "path": ["script", "pytorch"], @@ -19,9 +24,21 @@ }, "vit": { "model": timm.create_model("vit_base_patch16_224", pretrained=True), - "path": "script", + "path": ["script", "pytorch"], + }, + # "vit_large": { + # "model": timm.create_model("vit_giant_patch14_224", pretrained=False), + # "path": ["script", "pytorch"], + # }, + "bert_base_uncased": { + "model": cm.BertModule(), + "inputs": cm.BertInputs(), + "path": ["trace", "pytorch"], + }, + "sd_unet": { + "model": cm.StableDiffusionUnet(), + "path": "pytorch", }, - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, } @@ -51,7 +68,11 @@ def parse_inputs(user_inputs, dtype): ) for input_dim in input_shape_and_dtype[0][1:-1].split(","): input_shape.append(int(input_dim)) - torchtrt_inputs.append(torch.randint(0, 5, input_shape, dtype=dtype).cuda()) + + if input_shape != [1]: + torchtrt_inputs.append(torch.randn(input_shape, dtype=dtype).cuda()) + else: + torchtrt_inputs.append(torch.Tensor([1.0]).cuda()) return torchtrt_inputs