Commit 3c87214

Author: Wei
[FX] refactor the fx path in compile function (#1141)
* compile interface
* add compile method
* update
* update
* Update lower_setting.py
* update fx2trt_example
* add docstring
* update dynamic_batch default to False
* add docstring
* add save/load module
1 parent 5b03083 commit 3c87214

8 files changed (+143, -161 lines)

py/torch_tensorrt/_compile.py

Lines changed: 11 additions & 73 deletions
@@ -3,9 +3,11 @@
 import torch_tensorrt.ts
 from torch_tensorrt import logging
 import torch
-from torch import fx
+import torch.fx
 from enum import Enum
-from torch_tensorrt import fx
+import torch_tensorrt.fx
+from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.utils import LowerPrecision

 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     elif target_ir == _IRType.fx:
-        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
-        from torch_tensorrt.fx import InputTensorSpec
-        from torch_tensorrt.fx import TRTInterpreter
-        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
-        from torch_tensorrt.fx.trt_module import TRTModule
-        from torch_tensorrt.fx.utils import LowerPrecision
-        acc_model = acc_tracer.trace(module, inputs)
-
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is large than the \
-                threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
-            precision = LowerPrecision.FP16
+        if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions:
+            lower_precision = LowerPrecision.FP16
+        elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions:
+            lower_precision = LowerPrecision.FP32
         else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
-                )
-                # For profile
-                # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-        return split_mod
+            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+
+        return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")

py/torch_tensorrt/fx/example/fx2trt_example.py

Lines changed: 52 additions & 29 deletions
@@ -3,11 +3,11 @@
 import torch
 import torch.fx
 import torch.nn as nn
+from torch_tensorrt.fx.utils import LowerPrecision
 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter

-
 # The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
 # model to TensorRT via FX with existing FX based tooling. The general lowering flow
 # would be like:
@@ -30,11 +30,12 @@ def forward(self, x):
         x = self.linear(x)
         x = self.relu(x)
         x = torch.linalg.norm(x, ord=2, dim=1)
+        x = self.relu(x)
         return x


-inputs = [torch.randn(1, 10)]
-model = Model().eval()
+inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
+model = Model().cuda().eval()

 # acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
 # to acc ops.
@@ -64,20 +65,23 @@ def forward(self, x):
 # Split.
 split_mod = splitter()

-# After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
+# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
 print(split_mod.graph)
 """
 graph():
     %x : [#users=1] = placeholder[target=x]
     %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
     %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
-    return _run_on_gpu_1
+    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
+    return _run_on_acc_2
 """

 # Take a look at what inside each submodule. _run_on_acc_0 contains linear and relu while
-# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
+# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt. _run_on_acc_2
+# is another supported submodule.
 print(split_mod._run_on_acc_0.graph)
 print(split_mod._run_on_gpu_1.graph)
+print(split_mod._run_on_acc_2.graph)
 """
 graph():
     %x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
     %relu_1 : [#users=1] = placeholder[target=relu_1]
     %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
     return linalg_norm_1
+graph():
+    %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
+    %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
+    return relu_3
 """

-# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
-# we can skip the splitter part.
-interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
-r = interp.run()
-trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
-split_mod._run_on_acc_0 = trt_mod
-
-cuda_inputs = [input.cuda() for input in inputs]
-split_mod.cuda()
-lowered_model_output = split_mod(*cuda_inputs)
+def get_submod_inputs(mod, submod, inputs):
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
+
+# Since the model is split into three segments, we need to lower each TRT-eligible segment.
+# If we know the model can be fully lowered, we can skip the splitter part.
+for name, _ in split_mod.named_children():
+    if "_run_on_acc" in name:
+        submod = getattr(split_mod, name)
+        # Get submodule inputs for fx2trt
+        acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+        # fx2trt replacement
+        interp = TRTInterpreter(
+            submod,
+            InputTensorSpec.from_tensors(acc_inputs),
+            explicit_batch_dimension=True,
+        )
+        r = interp.run(lower_precision=LowerPrecision.FP32)
+        trt_mod = TRTModule(*r)
+        setattr(split_mod, name, trt_mod)
+
+lowered_model_output = split_mod(*inputs)
+
+# Save and load model
+torch.save(split_mod, "trt.pt")
+reload_trt_mod = torch.load("trt.pt")
+reload_model_output = reload_trt_mod(*inputs)

 # Make sure the results match
-model.cuda()
-regular_model_output = model(*cuda_inputs)
+regular_model_output = model(*inputs)
 torch.testing.assert_close(
-    lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2
+    reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
 )
-
-# We can utilize the trt profiler to print out the time spend on each layer.
-trt_mod.enable_profiling()
-trt_mod(*cuda_inputs)
-"""
-Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
-LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
-PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
-"""
-trt_mod.disable_profiling()
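One piece worth calling out: the example now captures the tensors each split submodule actually receives by registering a temporary forward_pre_hook and running one forward pass through the parent. The same trick in isolation, as plain PyTorch (the toy module here is illustrative, not from this commit):

import torch
import torch.nn as nn

mod = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def get_submod_inputs(mod, submod, inputs):
    acc_inputs = None

    def get_input(self, inputs):
        nonlocal acc_inputs
        acc_inputs = inputs  # the positional args the submodule receives

    handle = submod.register_forward_pre_hook(get_input)
    mod(*inputs)  # one forward pass through the parent records the tensors
    handle.remove()
    return acc_inputs

captured = get_submod_inputs(mod, mod[2], [torch.randn(3, 4)])
print(captured[0].shape)  # torch.Size([3, 8]), the input reaching mod[2]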

py/torch_tensorrt/fx/example/lower_example.py

Lines changed: 3 additions & 3 deletions
@@ -198,6 +198,6 @@ def run_configuration_benchmark(


 if __name__ == "__main__":
-    test_model = torchvision.models.resnet101()
-    input = [torch.cuda.FloatTensor(1024, 3, 224, 224)]  # type: ignore[attr-defined]
-    benchmark(test_model, input, 100, 1024)
+    test_model = torchvision.models.resnet18(pretrained=True)
+    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 50, 128)
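The default benchmark config drops from resnet101 at batch 1024 to a pretrained resnet18 at batch 128, presumably so it runs on smaller GPUs. For context, a timing loop of the kind a call like benchmark(test_model, input, 50, 128) has to wrap must handle warm-up and explicit CUDA synchronization; a minimal sketch of that measurement pattern (illustrative only; the committed helper and its exact signature live in lower_example.py):

import time
import torch

def simple_benchmark(model, inputs, iters=50):
    # Illustrative stand-in for the benchmark() helper; not the committed code.
    model = model.cuda().eval()
    inputs = [i.cuda() for i in inputs]
    with torch.inference_mode():
        for _ in range(5):  # warm-up so lazy initialization is not timed
            model(*inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(*inputs)
        torch.cuda.synchronize()  # CUDA launches are async; wait before stopping the clock
    elapsed_ms = (time.perf_counter() - start) / iters * 1000
    print(f"{elapsed_ms:.2f} ms/iter")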

py/torch_tensorrt/fx/example/test_fx2trt.py

Lines changed: 0 additions & 54 deletions
This file was deleted.
New file: Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import torch
+import copy
+import torchvision
+import torch_tensorrt
+from torch_tensorrt.fx import InputTensorSpec
+
+
+def test_torch_tensorrt(model, inputs):
+    # torchscript path
+    model_ts = copy.deepcopy(model)
+    inputs_ts = copy.deepcopy(inputs)
+    # fp32 test
+    with torch.inference_mode():
+        ref_fp32 = model_ts(*inputs_ts)
+    trt_ts_module = torch_tensorrt.compile(
+        model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
+    )
+    result_fp32 = trt_ts_module(*inputs_ts)
+    assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999)
+    # fp16 test
+    model_ts = model_ts.half()
+    inputs_ts = [i.cuda().half() for i in inputs_ts]
+    with torch.inference_mode():
+        ref_fp16 = model_ts(*inputs_ts)
+    trt_ts_module = torch_tensorrt.compile(
+        model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
+    )
+    result_fp16 = trt_ts_module(*inputs_ts)
+    assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99)
+
+    # FX path
+    model_fx = copy.deepcopy(model)
+    inputs_fx = copy.deepcopy(inputs)
+    # fp32 test
+    with torch.inference_mode():
+        ref_fp32 = model_fx(*inputs_fx)
+    trt_fx_module = torch_tensorrt.compile(
+        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
+    )
+    result_fp32 = trt_fx_module(*inputs_fx)
+    assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999)
+    # fp16 test
+    model_fx = model_fx.cuda().half()
+    inputs_fx = [i.cuda().half() for i in inputs_fx]
+    with torch.inference_mode():
+        ref_fp16 = model_fx(*inputs_fx)
+    trt_fx_module = torch_tensorrt.compile(
+        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
+    )
+    result_fp16 = trt_fx_module(*inputs_fx)
+    assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99)
+
+
+if __name__ == "__main__":
+    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
+    inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))]  # type: ignore[attr-defined]
+    test_torch_tensorrt(model, inputs)
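Note the acceptance criterion: rather than elementwise closeness, the test flattens both outputs and requires their cosine similarity to be near 1.0, with a looser 0.99 bound under fp16. The check in isolation, on synthetic tensors (illustrative only, not part of the commit):

import torch

ref = torch.randn(32, 1000)
out = ref + 1e-3 * torch.randn_like(ref)  # stand-in for a TRT output with small numeric drift
sim = torch.nn.functional.cosine_similarity(ref.flatten(), out.flatten(), dim=0)
assert sim > 0.9999, f"outputs diverged: cosine similarity {sim:.6f}"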

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 15 additions & 0 deletions
@@ -164,6 +164,21 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
     ) -> TRTInterpreterResult:
+        """
+        Build TensorRT engine with some configs.
+        Args:
+            max_batch_size: set accordingly for maximum batch size you will use.
+            max_workspace_size: set to the maximum size we can afford for temporary buffer
+            lower_precision: the precision model layers are running on (TensorRT will choose the best performance precision).
+            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
+            force_fp32_output: force output to be fp32
+            strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
+            algorithm_selector: set up algorithm selection for certain layer
+            timing_cache: enable timing cache for TensorRT
+            profiling_verbosity: TensorRT logging level
+        Return:
+            TRTInterpreterResult
+        """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
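The docstring documents knobs run() already accepted. A self-contained sketch exercising two of them, assembled from the APIs this commit uses elsewhere (toy model; assumes a CUDA device and a working TensorRT install):

import torch
import torch.nn as nn
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda().eval()
inputs = [torch.randn((8, 10), device="cuda")]
acc_mod = acc_tracer.trace(model, inputs)  # map torch ops to acc ops first

interp = TRTInterpreter(
    acc_mod,
    InputTensorSpec.from_tensors(inputs),
    explicit_batch_dimension=True,
)
r = interp.run(
    max_workspace_size=1 << 30,           # 1 GiB of builder scratch space
    lower_precision=LowerPrecision.FP16,  # let TensorRT pick fp16 kernels where valid
)
trt_mod = TRTModule(*r)
print(trt_mod(*inputs).shape)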
