import torch
import torch.fx
import torch.nn as nn
+ from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-
# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX-based tooling. The general lowering flow
# is as follows:
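# (In short, as the code below shows: trace the model with acc_tracer, split it
# with TRTSplitter, then lower each TRT-eligible segment with TRTInterpreter and
# swap the result in as a TRTModule.)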
@@ -30,11 +30,12 @@ def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = torch.linalg.norm(x, ord=2, dim=1)
+        x = self.relu(x)
        return x
- inputs = [torch.randn(1, 10)]
- model = Model().eval()
+ inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
+ model = Model().cuda().eval()
# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
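# For reference, the tracing and splitting steps (elided by the hunk below) look
# roughly like this, assuming the `traced` and `splitter` names that
# `split_mod = splitter()` relies on:
#
#     traced = acc_tracer.trace(model, inputs)
#     splitter = TRTSplitter(traced, inputs)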
@@ -64,20 +65,23 @@ def forward(self, x):
# Split.
split_mod = splitter()
- # After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
+ # After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
print(split_mod.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
    %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
-     return _run_on_gpu_1
+     %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
+     return _run_on_acc_2
"""
# Take a look at what is inside each submodule. _run_on_acc_0 contains linear and relu, while
- # _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
+ # _run_on_gpu_1 contains linalg_norm, which currently is not supported by fx2trt.
+ # _run_on_acc_2 is another supported submodule.
print(split_mod._run_on_acc_0.graph)
print(split_mod._run_on_gpu_1.graph)
+ print(split_mod._run_on_acc_2.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
    %relu_1 : [#users=1] = placeholder[target=relu_1]
    %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
    return linalg_norm_1
+ graph():
+     %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
+     %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
+     return relu_3
"""
- # Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
- # we can skip the splitter part.
- interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
- r = interp.run()
- trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
- split_mod._run_on_acc_0 = trt_mod
-
- cuda_inputs = [input.cuda() for input in inputs]
- split_mod.cuda()
- lowered_model_output = split_mod(*cuda_inputs)
+ def get_submod_inputs(mod, submod, inputs):
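+     # Run a full forward pass of `mod` and capture, via a forward pre-hook, the
+     # positional args that reach `submod`; these real tensors become the sample
+     # inputs fx2trt needs for this segment.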
+     acc_inputs = None
+
+     def get_input(self, inputs):
+         nonlocal acc_inputs
+         acc_inputs = inputs
+
+     handle = submod.register_forward_pre_hook(get_input)
+     mod(*inputs)
+     handle.remove()
+     return acc_inputs
+
+ # Since the model is split into three segments, we need to lower each TRT-eligible
+ # segment. If we know the model can be fully lowered, we can skip the splitter part.
+ for name, _ in split_mod.named_children():
+     if "_run_on_acc" in name:
+         submod = getattr(split_mod, name)
+         # Get submodule inputs for fx2trt
+         acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+         # fx2trt replacement
+         interp = TRTInterpreter(
+             submod,
+             InputTensorSpec.from_tensors(acc_inputs),
+             explicit_batch_dimension=True,
+         )
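+         # interp.run builds the TensorRT engine for this segment; LowerPrecision.FP32
+         # keeps it in full precision (LowerPrecision.FP16 is the usual alternative).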
+         r = interp.run(lower_precision=LowerPrecision.FP32)
+         trt_mod = TRTModule(*r)
+         setattr(split_mod, name, trt_mod)
+
+ lowered_model_output = split_mod(*inputs)
+
+ # Save and load model
+ torch.save(split_mod, "trt.pt")
+ reload_trt_mod = torch.load("trt.pt")
+ reload_model_output = reload_trt_mod(*inputs)
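+ # (torch.save/torch.load work here because TRTModule knows how to serialize its
+ # TensorRT engine when pickled, so the lowered module round-trips intact.)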
# Make sure the results match
- model.cuda()
- regular_model_output = model(*cuda_inputs)
+ regular_model_output = model(*inputs)
torch.testing.assert_close(
-     lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2
+     reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
)
-
- # We can utilize the TRT profiler to print out the time spent on each layer.
- trt_mod.enable_profiling()
- trt_mod(*cuda_inputs)
- """
- Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
- LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
- PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
- """
- trt_mod.disable_profiling()