
Commit 369fb71

Author: Wei

[FX] move example folder (#1149)

* move example folder
* change folder name to fx
* add hf example
1 parent 93f1a5f commit 369fb71

File tree

7 files changed, +345 -0 lines changed

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
import argparse
import copy
import gc
import time
from functools import partial

import numpy as np
import pandas as pd
import torch
from transformers import AutoConfig
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM
from transformers import BertConfig, ReformerConfig, XLNetModel, XLNetConfig

import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
from torchdynamo.testing import collect_results
from torchdynamo.testing import same

torch.backends.cuda.matmul.allow_tf32 = True


# This example benchmarks Hugging Face models. Since these models cannot be
# traced directly by the acc tracer (which is based on torch.fx), we combine
# the torch_tensorrt FX path with TorchDynamo. To illustrate performance, each
# model is benchmarked at several batch sizes.

benchmarks = [
    # Longformer is not suitable for torch_tensorrt-fx
    # (
    #     AutoConfig.from_pretrained("allenai/longformer-base-4096"),
    #     AutoModelForMaskedLM,
    #     (2, 1024),
    #     [torch.bfloat16],  # trilu not implemented for bfloat16
    # ),
    # (ReformerConfig(), AutoModelForMaskedLM, (8, 4096), []),  # Reformer is not suitable for torch_tensorrt-fx
    # (BigBirdConfig(attention_type="block_sparse"), AutoModelForMaskedLM, (2, 1024), []),  # BigBird is not suitable for torch_tensorrt-fx
    # (AutoConfig.from_pretrained("google/fnet-base"), AutoModelForMaskedLM, (4, 512), []),  # not supported by torch_tensorrt-fx

    # batch size = 1
    (BertConfig(), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("albert-base-v2"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("gpt2"), AutoModelForCausalLM, (1, 512), []),
    (AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM, (1, 512), []),
    (AutoConfig.from_pretrained("distilbert-base-uncased"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("roberta-base"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM, (1, 512), []),
    (AutoConfig.from_pretrained("google/electra-base-discriminator"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("YituTech/conv-bert-base"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("google/mobilebert-uncased"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("camembert-base"), AutoModelForMaskedLM, (1, 512), []),
    (AutoConfig.from_pretrained("microsoft/layoutlm-base-uncased"), AutoModelForMaskedLM, (1, 512), []),
    # batch size = 4
    (BertConfig(), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("albert-base-v2"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("gpt2"), AutoModelForCausalLM, (4, 512), []),
    (AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM, (4, 512), []),
    (AutoConfig.from_pretrained("distilbert-base-uncased"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("roberta-base"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM, (4, 512), []),
    (AutoConfig.from_pretrained("google/electra-base-discriminator"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("YituTech/conv-bert-base"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("google/mobilebert-uncased"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("camembert-base"), AutoModelForMaskedLM, (4, 512), []),
    (AutoConfig.from_pretrained("microsoft/layoutlm-base-uncased"), AutoModelForMaskedLM, (4, 512), []),
    # batch size = 8
    (BertConfig(), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("albert-base-v2"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("gpt2"), AutoModelForCausalLM, (8, 512), []),
    (AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM, (8, 512), []),
    (AutoConfig.from_pretrained("distilbert-base-uncased"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("roberta-base"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM, (8, 512), []),
    (AutoConfig.from_pretrained("google/electra-base-discriminator"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("YituTech/conv-bert-base"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("google/mobilebert-uncased"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("camembert-base"), AutoModelForMaskedLM, (8, 512), []),
    (AutoConfig.from_pretrained("microsoft/layoutlm-base-uncased"), AutoModelForMaskedLM, (8, 512), []),
]
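# Each entry above is (config, model_class, (batch_size, sequence_length), unsupported_dtypes);
# run_all_eval skips an entry when the active dtype appears in its unsupported_dtypes list.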

device = "cuda"


# No-op context manager; stands in for an optimization context in the eager baseline.
class NullContext:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


@torchdynamo.skip
def get_cur_memory():
    torch.cuda.synchronize()

    gc.collect()
    torch.cuda.empty_cache()
    stats = torch.cuda.memory_stats()
    peak_bytes_requirement = stats["allocated_bytes.all.current"]
    # print(f"Current memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
    return peak_bytes_requirement


@torchdynamo.skip
def forward_pass(mod, inputs, collect_outputs=True):
    return mod(*inputs)

# Correctness check: compare the optimized run against eager mode
@torchdynamo.skip
def check_correctness(args, mod, inputs, optimize_ctx, optimize_name):
    torch.manual_seed(1337)
    correct_result = forward_pass(copy.deepcopy(mod), inputs)

    torch.manual_seed(1337)
    correct_rerun_result = forward_pass(copy.deepcopy(mod), inputs)

    if not same(correct_result, correct_rerun_result):
        print("INCORRECT - Variation in Eager runs itself")
        return False

    torch.manual_seed(1337)
    torchdynamo.reset()
    try:
        with optimize_ctx:
            new_result = forward_pass(mod, inputs)
    except Exception:
        print("ERROR")
        return False

    # fp16 runs are compared via cosine similarity rather than strict elementwise tolerance
    cos_similarity = optimize_name == "dynamo_fx2trt_fp16"

    if not same(correct_result, new_result, cos_similarity=cos_similarity, tol=1e-2):
        print("INCORRECT")
        return False
    return True


synchronize = torch.cuda.synchronize

# timing function to record the repeated run time
def timed(model, model_iter_fn, train_inputs, timings=1, return_result=False):
    synchronize()
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    # Don't collect outputs, so timing is measured correctly
    for _ in range(timings):
        result = model_iter_fn(model, train_inputs, collect_outputs=False)
        synchronize()
    t1 = time.perf_counter()
    # print("===timed=", t1 - t0)
    return (t1 - t0, result) if return_result else t1 - t0

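# Note: CUDA kernels launch asynchronously, so the synchronize() calls in timed()
# bracket the measured region; without them, perf_counter would largely measure
# kernel launch overhead rather than execution time.
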
# benchmark function: repeated runs of a Hugging Face model after tracing by torchdynamo and lowering through torch_tensorrt-fx
@torchdynamo.skip
def bench_model_eval(args, name, mod, eval_inputs, optimize_ctx):
    if isinstance(optimize_ctx, NullContext):
        # Profile memory
        m = None
        for i in range(5):
            out = mod(*eval_inputs)
            if i == 4:
                m = get_cur_memory()

        # Warmup
        iters = 5
        for _ in range(iters):
            timed(mod, forward_pass, eval_inputs)
        synchronize()

        # Profile time
        iters = 50
        synchronize()
        timings = []
        for _ in range(iters):
            timings.append(timed(mod, forward_pass, eval_inputs))
        t = np.median(timings, axis=0)
    else:
        # torchdynamo.run() reuses the graphs compiled during the correctness check,
        # so no recompilation is needed here; this branch is the fx2trt demo path
        with torchdynamo.run():
            # Profile memory
            m = None
            for i in range(5):
                out = mod(*eval_inputs)
                if i == 4:
                    m = get_cur_memory()

            # Warmup
            iters = 5
            for _ in range(iters):
                timed(mod, forward_pass, eval_inputs)
            synchronize()

            # Profile time
            iters = 50
            synchronize()
            timings = []
            for _ in range(iters):
                timings.append(timed(mod, forward_pass, eval_inputs))
            t = np.median(timings, axis=0)

    print(name, t, m)
    return t, m


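# Column keys for the result records below: nh=name, th=time (s), mh=mem (GB),
# sp=speedup, mp=mem_compression, acc=is_accurate.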
model_header, dtype_header, nh, th, mh, sp, mp, acc = (
    "model",
    "dtype",
    "name",
    "time (s)",
    "mem (GB)",
    "speedup",
    "mem_compression",
    "is_accurate",
)


def create_record(model_name, dtype, is_accurate, name, t, m):
    return {
        model_header: model_name,
        dtype_header: str(dtype),
        acc: is_accurate,
        nh: name,
        th: t,
        mh: m / 2 ** 30,  # bytes to GB
    }


numerical_diffs = []
results = []


def load_model(config, model_type, dtype, args):
    for attr in dir(config):
        if "drop" in attr and isinstance(getattr(config, attr), float):
            setattr(
                config, attr, 1e-30
            )  # So we can check for correct gradients without eliminating the dropout computation
    model = model_type.from_config(config).to(device, dtype=dtype)
    model.eval()
    return model


# Adapts seq2seq models, whose forward expects keyword arguments, to the
# positional-argument calling convention used by forward_pass.
class ArgsToKwargsWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ArgsToKwargsWrapper, self).__init__()
        self.model = model

    def forward(self, input_ids, decoder_input_ids):
        return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


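# For each benchmark entry below: build the model, generate random token ids,
# check numerical parity against eager, then time the eager baseline and the
# optimized run, recording speedup and memory compression.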
def run_all_eval(args, optimize_ctx, optimize_name, dtype):
    for config, model_type, input_size, not_supported_dtypes in benchmarks:
        if dtype in not_supported_dtypes:
            continue

        model = load_model(config, model_type, dtype, args)

        model_name = type(model).__name__

        # Prepare inputs
        input_ids = torch.randint(0, config.vocab_size, input_size).to(device)

        if model_type.__name__ == "AutoModelForSeq2SeqLM":
            model = ArgsToKwargsWrapper(model)
            eval_inputs = (input_ids, input_ids)
        else:
            eval_inputs = (input_ids,)

        # Correctness check
        is_accurate = check_correctness(args, model, eval_inputs, optimize_ctx, optimize_name)
        # Profile eager
        t, m = bench_model_eval(args, "eager", model, eval_inputs, NullContext())
        results.append(create_record(model_name, dtype, is_accurate, "eager", t, m))

        # Profile the optimized run (the chosen TorchDynamo backend)
        t, m = bench_model_eval(args, optimize_name, model, eval_inputs, optimize_ctx)
        results.append(create_record(model_name, dtype, is_accurate, optimize_name, t, m))

        # Calculate relative improvements over the eager baseline
        base_r = results[-2]
        for r in results[-2:]:
            r[sp] = round(base_r[th] / r[th], 3)
            r[mp] = round(base_r[mh] / r[mh], 3)
        print(pd.DataFrame(results[-2:]).to_markdown(index=False, floatfmt=".3f"))

    print("=== Final results ===")
    print(pd.DataFrame(results).to_markdown(index=False, floatfmt=".3f"))


def main():
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--run-dynamo-eager",
        action="store_true",
        help="Use Dynamo eager",
    )
    group.add_argument(
        "--run-dynamo-fx2trt-fp16",
        action="store_true",
        help="Use Dynamo with fx2trt fp16",
    )
    group.add_argument(
        "--run-dynamo-fx2trt-fp32",
        action="store_true",
        help="Use Dynamo with fx2trt fp32",
    )
    args = parser.parse_args()
    optimize_ctx = NullContext()
    optimize_name = "eager"

    if args.run_dynamo_eager:
        optimize_ctx = torchdynamo.optimize("eager")
        optimize_name = "dynamo_eager"
    elif args.run_dynamo_fx2trt_fp16:
        optimize_ctx = torchdynamo.optimize(
            backends.fx2trt_compiler_fp16
        )
        optimize_name = "dynamo_fx2trt_fp16"
    elif args.run_dynamo_fx2trt_fp32:
        optimize_ctx = torchdynamo.optimize(
            backends.fx2trt_compiler
        )
        optimize_name = "dynamo_fx2trt_fp32"

    experiment = run_all_eval
    # fp16 for the fx2trt fp16 backend; fp32 otherwise (run_all_eval requires a
    # dtype, so defaulting to fp32 also covers the eager and dynamo_eager paths)
    if optimize_name == "dynamo_fx2trt_fp16":
        experiment = partial(experiment, dtype=torch.float16)
    else:
        experiment = partial(experiment, dtype=torch.float32)

    experiment = partial(experiment, optimize_ctx=optimize_ctx, optimize_name=optimize_name)
    experiment(args)


if __name__ == "__main__":
    main()
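
For orientation, here is a minimal sketch of the compile-then-replay pattern the script exercises, using only the torchdynamo APIs the script itself imports; run_fx2trt_fp16, mod, and inputs are illustrative placeholders:

import torchdynamo
from torchdynamo.optimizations import backends

def run_fx2trt_fp16(mod, inputs):
    # The first call under optimize() lets torchdynamo trace the model and hand
    # the captured FX graphs to the fx2trt backend for lowering to TensorRT.
    with torchdynamo.optimize(backends.fx2trt_compiler_fp16):
        mod(*inputs)
    # Subsequent calls under run() reuse the compiled graphs without retracing.
    with torchdynamo.run():
        return mod(*inputs)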

py/torch_tensorrt/fx/README.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
FX2TRT is merged as the FX module in Torch-TensorRT.

- The user guide is at [link](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
- The examples have moved to [link](https://github.com/pytorch/TensorRT/tree/master/examples/fx_example)
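
A hedged sketch of what using the FX path might look like, per the linked user guide; the ir="fx" argument and exact signature are assumptions to verify against the guide, and MyModel is a placeholder:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder module
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Assumption: torch_tensorrt.compile(..., ir="fx") routes compilation through
# the FX path described in getting_started_with_fx_path.
trt_model = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=inputs,
    enabled_precisions={torch.half},
)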
