import logging
- from typing import Dict, List, Optional, Sequence, Set
+ from typing import Dict, List, Optional, Sequence, Set, Tuple

import torch

+ from torch.fx.passes.splitter_base import (
+     Subgraph,
+     _SplitterBase,
+     _SplitterSettingBase,
+     FxNetAccNodesFinder,
+     FxNetAccFusionsFinder,
+ )
+ import torch.fx.passes.operator_support as ops
+ from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
+ from torch.fx.node import Target
+
+ from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
- from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
- from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
- from torch.fx.passes.operator_support import OperatorSupport

from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS

)

- class TRTPartitioner(CapabilityBasedPartitioner):
-     """Partitioner to split an FX graph into subgraphs based on operator support
-
-     Args:
-         graph_module: FX GraphModule to partition
-         operator_support: OperatorSupport class describing allowed operators
-         non_compute_ops: Operators which are not considered computational (e.g. getattr)
-         allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
-             Generally useful for module-level exclusion ops which are intensive despite being single functions
-         min_block_size: Minimum number of computational operators per block
-     Returns:
-         torch.fx.GraphModule
-     """
-
-     def __init__(
-         self,
-         graph_module: GraphModule,
-         operator_support: OperatorSupport,
-         *,
-         non_compute_ops: Optional[Sequence[str]] = None,
-         allowed_single_node_partition_ops: Optional[
-             Sequence[str]
-         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
-         min_block_size=MIN_BLOCK_SIZE,
-     ) -> None:
-         super().__init__(
-             graph_module,
-             operator_support,
-             allows_single_node_partition=True,
-             non_compute_ops=non_compute_ops,
-             allowed_single_node_partition_ops=allowed_single_node_partition_ops,
-         )
-
-         self.min_block_size = min_block_size
-
-     def propose_partitions(self) -> List[Partition]:
-         # Propose partitions using the default, then refine the results
-         initial_proposed_partitions = super().propose_partitions()
-         partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
-
-         # For each partition, determine whether or not the number of computational operators
-         # exceeds the threshold, and if not, remove that partition
-         partitions_to_remove = {}
-         for id, partition in partitions.items():
-             default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
-             non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
-             exempted_partition = False
-
-             compute_node_count = 0
-             for node in partition.nodes:
-                 # Partitions are exempted from min_block_size if they contain an allowed single-node op
-                 if (
-                     node.op == "call_function"
-                     and _get_qualified_name(node.target)
-                     in self.allowed_single_node_partition_ops
-                 ):
-                     exempted_partition = True
-                     break
-                 elif (
-                     node.op == "call_function"
-                     and _get_qualified_name(node.target) not in non_compute_ops
-                 ):
-                     compute_node_count += 1
-
-             if compute_node_count < self.min_block_size and not exempted_partition:
-                 partitions_to_remove[id] = compute_node_count
-
-         # Remove any partitions violating the criteria specified by the user
-         for id, count in partitions_to_remove.items():
-             logger.debug(
-                 f"Removing partition which has {count} < {self.min_block_size} computational operators"
-             )
-             del partitions[id]
-
-         return [partitions[k] for k in sorted(partitions.keys())]
-
-     def partition_and_fuse(self) -> GraphModule:
-         partitions = self.propose_partitions()
-         fused_gm = self.fuse_partitions(partitions)
-         return fused_gm
-
-
- class TorchTensorRTOperatorSupport(OperatorSupport):
+ class OpSupportTester(ops.OperatorSupportBase):
    """Class to determine whether operators within a module are supported"""

-     def __init__(self, support_dict=None, torch_executed_ops=set()):
-         super().__init__(support_dict)
+     def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+         super().__init__()

        # Initialize sets of supported/unsupported operators
        self.supported_operators = {}
@@ -117,11 +44,7 @@ def __init__(self, support_dict=None, torch_executed_ops=set()):
    def is_node_supported(
        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
-         node_name = (
-             _get_qualified_name(node.target)
-             if not isinstance(node.target, str)
-             else node.target
-         )
+         node_name = ConverterRegistry.qualified_name_or_str(node.target)

        if node in CONVERTERS and node_name not in self.torch_executed_ops:
            # If node is a proper, supported computational node, store the operator
@@ -164,11 +87,139 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
            logger.debug("\nAll Nodes Supported\n")

+ class TRTPartitioner(_SplitterBase):
+     """Partitioner to split an FX graph into subgraphs based on operator support
+
+     Adapted from, and modified for the Torch-TensorRT Dynamo case:
+     https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871
+
+     Args:
+         graph_module: FX GraphModule to partition
+         operator_support: OperatorSupport class describing allowed operators
+         allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+             Generally useful for module-level exclusion ops which are intensive despite being single functions
+         min_block_size: Minimum number of computational operators per block
+     Returns:
+         torch.fx.GraphModule
+     """
+
+     def __init__(
+         self,
+         graph_module: torch.fx.GraphModule,
+         operator_support: ops.OperatorSupportBase,
+         allowed_single_node_partition_ops: Optional[
+             Sequence[str]
+         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+         min_block_size: int = MIN_BLOCK_SIZE,
+     ):
+         """
+         Preprocesses graph before splitting:
+         - finds nodes supported by ACC,
+         - finds fusion groups for ACC nodes having non-tensor IO,
+         - builds a graph of direct dependencies,
+         - builds a map of fused nodes to their fusions.
+         As a result we get self.acc_nodes, self.deps and self.fusions.
+         """
+         assert isinstance(graph_module, torch.fx.GraphModule)
+
+         self.graph_module = graph_module
+
+         self.settings = _SplitterSettingBase(
+             min_acc_module_size=min_block_size, allow_non_tensor=True
+         )
+         self.operator_support = operator_support
+
+         # Get all accelerated nodes based on operator support conditions
+         self.acc_nodes = FxNetAccNodesFinder(
+             self.graph_module, self.operator_support, self.settings.allow_non_tensor
+         )()
+
+         if self.settings.skip_fusion:
+             self.fusions = {}
+         else:
+             self.fusions = FxNetAccFusionsFinder(graph_module, self.acc_nodes)()
+
+         # Modify deps to add more deps for fused nodes
+         self.deps = self.find_deps()
+         self.update_deps_for_fusions()
+
+         self.non_acc_submodule_name = "_run_on_gpu_"
+         self._node_submodule_map: Dict[str, str] = {}
+
+         self.num_trt_accelerated_subgraphs = None
+         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+
+     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+         """
+         This pass finds ACC submodules with less than the specified size and merges
+         them with adjacent GPU submodules.
+         """
+         result: List[Subgraph] = []
+         for subgraph in subgraphs:
+             if subgraph.is_acc:
+                 if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
+                     ConverterRegistry.qualified_name_or_str(node.target)
+                     in self.allowed_single_node_partition_ops
+                     for node in subgraph.nodes
+                 ):
+                     result.append(subgraph)
+                 else:
+                     logger.debug(
+                         "Eliminating acc subgraph because it's smaller than the threshold: "
+                         f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                     )
+                     if result:
+                         result[-1].nodes.extend(subgraph.nodes)
+                     else:
+                         subgraph.is_acc = False
+                         result.append(subgraph)
+             else:
+                 if result and not result[-1].is_acc:
+                     result[-1].nodes.extend(subgraph.nodes)
+                 else:
+                     result.append(subgraph)
+         return result
+
+     def partition_graph(self) -> torch.fx.GraphModule:
+         """Partitions the GraphModule into subgraphs based on operator support
+
+         Returns a GraphModule with submodules for each segment
+         """
+         # Delegate nodes based on operator coverage
+         subgraphs = self.put_nodes_into_subgraphs()
+
+         # Remove segments smaller than the block size (with exceptions)
+         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+
+         # Set the number of TRT engines to be generated
+         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
+
+         # Tag the accelerated nodes and split the graph accordingly
+         self.tag(subgraphs)
+         return self.split()
+
+     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+         """Generates starter nodes for partitioning + segmentation"""
+         # Starter accelerated nodes are all callable accelerated ops
+         starter_acc_nodes = {
+             node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
+         }
+
+         # Starter non-accelerated nodes are the rest of the callable nodes
+         starter_non_acc_nodes = {
+             node
+             for node in self.graph_module.graph.nodes
+             if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
+         }
+
+         return starter_non_acc_nodes, starter_acc_nodes
+
+
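For orientation, a minimal sketch (not part of the diff) of how the two classes above are driven directly. The GraphModule `gm` is an assumed aten-level graph produced earlier in the Dynamo compile flow; the variable names are illustrative:

    # `gm` is an assumed aten-level torch.fx.GraphModule from the tracing stage
    supported_ops = OpSupportTester(torch_executed_ops=set())
    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=MIN_BLOCK_SIZE)
    partitioned_gm = partitioner.partition_graph()
    # Supported segments are tagged and split into accelerated submodules by the
    # base splitter; unsupported segments use the "_run_on_gpu_" prefix set in __init__.

The `partition` helper below wraps exactly this sequence, adding graph cleanup and the optional support overview.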

def partition(
    gm: torch.fx.GraphModule,
    verbose: bool = True,
    min_block_size: int = MIN_BLOCK_SIZE,
-     torch_executed_ops: Sequence[str] = set(),
+     torch_executed_ops: Sequence[Target] = set(),
) -> torch.fx.GraphModule:
    """Partition an FX GraphModule with aten ops into TRT engines
    Partitioning is based on converter operator support
@@ -181,18 +232,21 @@ def partition(
    Returns:
        torch.fx.GraphModule
    """
-     supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+     # Ensure graph is clean prior to partitioning
+     gm.graph.eliminate_dead_code()
+     gm.graph.lint()
+     gm.recompile()
+
+     # Construct the operator support checker and the partitioner
+     supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)

-     # Determine partitions based on user specifications and operator support
-     # Then, fuse partitions and display overview of supported/unsupported operators
-     partitions = partitioner.propose_partitions()
-     fused_graph = partitioner.fuse_partitions(partitions)
+     partitioned_graph = partitioner.partition_graph()

    if verbose:
-         supported_ops.print_support_overview(len(partitions))
+         supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)

-     return fused_graph
+     return partitioned_graph
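As a usage sketch (not part of the diff; the op name and block size are illustrative), callers typically go through this helper rather than constructing TRTPartitioner themselves:

    # `gm` is an assumed aten-level torch.fx.GraphModule from the tracing stage
    partitioned_gm = partition(
        gm,
        verbose=True,
        min_block_size=5,
        torch_executed_ops={"torch.ops.aten.add.Tensor"},  # illustrative: keep this op in PyTorch
    )

Any operator named in `torch_executed_ops` is reported as unsupported by `OpSupportTester.is_node_supported`, so its nodes land in non-accelerated segments even when a TensorRT converter exists.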


def get_submod_inputs(