
Commit 6cfbb59

feat: Improve Dynamo partitioning system
- Upgrade Dynamo partitioning to use a custom version of the Torch _SplitterBase for efficiency and optimized usage in the Dynamo case
- Validate existing use cases are still functional, with the same partitioning schema as before
- Upgrade qualified name checking
1 parent 0527edd commit 6cfbb59

4 files changed (+384 -103 lines)


py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 5 additions & 0 deletions
@@ -131,6 +131,11 @@ def _compile_module(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
+
+        # Criteria for a module to be convertible to TRT
+        if "_run_on_acc" not in name:
+            continue
+
         submodule = getattr(partitioned_module, name)

         # Get submodule inputs
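Note for readers: the _SplitterBase-derived partitioner (see _partition.py below) names each child submodule with a "_run_on_acc_*" or "_run_on_gpu_*" prefix, and the new check makes _compile_module convert only the accelerated segments. A minimal sketch of the filtering behavior, using a hypothetical partitioned module that is not part of this commit:

import torch

class FakePartitionedModule(torch.nn.Module):
    """Stand-in for a module produced by the splitter; names are illustrative."""

    def __init__(self) -> None:
        super().__init__()
        self._run_on_acc_0 = torch.nn.Linear(4, 4)  # would be sent to TRT conversion
        self._run_on_gpu_1 = torch.nn.ReLU()        # skipped, stays in PyTorch
        self._run_on_acc_2 = torch.nn.Linear(4, 4)  # would be sent to TRT conversion

partitioned_module = FakePartitionedModule()

# Mirrors the new filter in _compile_module
to_convert = [
    name for name, _ in partitioned_module.named_children() if "_run_on_acc" in name
]
print(to_convert)  # ['_run_on_acc_0', '_run_on_acc_2']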

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 2 additions & 1 deletion
@@ -305,7 +305,8 @@ def unique_targets(self):
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
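Making the helper a staticmethod lets callers resolve node-target names without holding a registry instance, which is how the new partitioning code uses it (ConverterRegistry.qualified_name_or_str(node.target)). A standalone sketch of the behavior, assuming the truncated body falls through to _get_qualified_name for callable targets:

from typing import Any, Callable, Union

import torch
from torch.fx.node import _get_qualified_name


def qualified_name_or_str(target: Union[str, Callable[..., Any]]) -> str:
    """Sketch of the registry helper: normalize an FX node target to a string."""
    if isinstance(target, str):
        return target  # call_method / call_module targets are already strings
    return _get_qualified_name(target)


print(qualified_name_or_str("relu"))                     # relu
print(qualified_name_or_str(torch.ops.aten.add.Tensor))  # e.g. torch.ops.aten.add.Tensor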

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 156 additions & 102 deletions
@@ -1,14 +1,23 @@
 import logging
-from typing import Dict, List, Optional, Sequence, Set
+from typing import Dict, List, Optional, Sequence, Set, Tuple

 import torch

+from torch.fx.passes.splitter_base import (
+    Subgraph,
+    _SplitterBase,
+    _SplitterSettingBase,
+    FxNetAccNodesFinder,
+    FxNetAccFusionsFinder,
+)
+import torch.fx.passes.operator_support as ops
+from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
+from torch.fx.node import Target
+
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
 from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
-from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
-from torch.fx.passes.operator_support import OperatorSupport

 from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
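These imports pull in the torch.fx splitter utilities that the new TRTPartitioner composes. As a quick illustration of the first building block (not part of this commit; the toy support checker and traced function are hypothetical), FxNetAccNodesFinder simply collects the callable nodes that a support checker approves of:

import torch
import torch.fx.passes.operator_support as ops
from torch.fx.passes.splitter_base import FxNetAccNodesFinder


class ReluOnlySupport(ops.OperatorSupportBase):
    """Toy checker: treat only torch.relu as 'accelerated'."""

    def is_node_supported(self, submodules, node) -> bool:
        return node.op == "call_function" and node.target is torch.relu


def toy_fn(x):
    return torch.sin(torch.relu(x))


gm = torch.fx.symbolic_trace(toy_fn)

# Same call pattern the new partitioner uses to seed its accelerated-node set
acc_nodes = FxNetAccNodesFinder(gm, ReluOnlySupport(), True)()
print(sorted(n.name for n in acc_nodes))  # ['relu']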

@@ -21,93 +30,11 @@
 )


-class TRTPartitioner(CapabilityBasedPartitioner):
-    """Partitioner to split an FX graph into subgraphs based on operator support
-
-    Args:
-        graph_module: FX GraphModule to partition
-        operator_support: OperatorSupport class describing allowed operators
-        non_compute_ops: Operators which are not considered computational (e.g. getattr)
-        allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
-            Generally useful for module-level exclusion ops which are intensive despite being single functions
-        min_block_size: Minimum number of computational operators per block
-    Returns:
-        torch.fx.GraphModule
-    """
-
-    def __init__(
-        self,
-        graph_module: GraphModule,
-        operator_support: OperatorSupport,
-        *,
-        non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[
-            Sequence[str]
-        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
-        min_block_size=MIN_BLOCK_SIZE,
-    ) -> None:
-        super().__init__(
-            graph_module,
-            operator_support,
-            allows_single_node_partition=True,
-            non_compute_ops=non_compute_ops,
-            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
-        )
-
-        self.min_block_size = min_block_size
-
-    def propose_partitions(self) -> List[Partition]:
-        # Propose partitions using the default, then refine the results
-        initial_proposed_partitions = super().propose_partitions()
-        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
-
-        # For each partition, determine whether or not the number of computational operators
-        # exceeds the threshold, and if not, remove that partition
-        partitions_to_remove = {}
-        for id, partition in partitions.items():
-            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
-            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
-            exempted_partition = False
-
-            compute_node_count = 0
-            for node in partition.nodes:
-                # Partitions are exempted from min_block_size if they contain an allowed single-node op
-                if (
-                    node.op == "call_function"
-                    and _get_qualified_name(node.target)
-                    in self.allowed_single_node_partition_ops
-                ):
-                    exempted_partition = True
-                    break
-                elif (
-                    node.op == "call_function"
-                    and _get_qualified_name(node.target) not in non_compute_ops
-                ):
-                    compute_node_count += 1
-
-            if compute_node_count < self.min_block_size and not exempted_partition:
-                partitions_to_remove[id] = compute_node_count
-
-        # Remove any nodes violating the criteria specified by the user
-        for id, count in partitions_to_remove.items():
-            logger.debug(
-                f"Removing partition which has {count} < {self.min_block_size} computational operators"
-            )
-            del partitions[id]
-
-        return [partitions[k] for k in sorted(partitions.keys())]
-
-    def partition_and_fuse(self) -> GraphModule:
-        partitions = self.propose_partitions()
-        fused_gm = self.fuse_partitions(partitions)
-        return fused_gm
-
-
-class TorchTensorRTOperatorSupport(OperatorSupport):
+class OpSupportTester(ops.OperatorSupportBase):
     """Class to determine whether operators within a module are supported"""

-    def __init__(self, support_dict=None, torch_executed_ops=set()):
-        super().__init__(support_dict)
+    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+        super().__init__()

         # Initialize sets of supported/unsupported operators
         self.supported_operators = {}
@@ -117,11 +44,7 @@ def __init__(self, support_dict=None, torch_executed_ops=set()):
     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        node_name = (
-            _get_qualified_name(node.target)
-            if not isinstance(node.target, str)
-            else node.target
-        )
+        node_name = ConverterRegistry.qualified_name_or_str(node.target)

         if node in CONVERTERS and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
@@ -164,11 +87,139 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         logger.debug("\nAll Nodes Supported\n")


+class TRTPartitioner(_SplitterBase):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Adapted from, and modified for the Torch-TensorRT Dynamo case:
+    https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: torch.fx.GraphModule,
+        operator_support: ops.OperatorSupportBase,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        min_block_size: int = MIN_BLOCK_SIZE,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(graph_module, torch.fx.GraphModule)
+
+        self.graph_module = graph_module
+
+        self.settings = _SplitterSettingBase(
+            min_acc_module_size=min_block_size, allow_non_tensor=True
+        )
+        self.operator_support = operator_support
+
+        # Get all accelerated nodes based on operator support conditions
+        self.acc_nodes = FxNetAccNodesFinder(
+            self.graph_module, self.operator_support, self.settings.allow_non_tensor
+        )()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(graph_module, self.acc_nodes)()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = "_run_on_gpu_"
+        self._node_submodule_map: Dict[str, str] = {}
+
+        self.num_trt_accelerated_subgraphs = None
+        self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+
+    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This pass finds ACC submodules with less than specified size and merges
+        them with adjacent GPU submodules.
+        """
+        result: List[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
+                    ConverterRegistry.qualified_name_or_str(node.target)
+                    in self.allowed_single_node_partition_ops
+                    for node in subgraph.nodes
+                ):
+                    result.append(subgraph)
+                else:
+                    logger.debug(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def partition_graph(self) -> torch.fx.GraphModule:
+        """Partitions the GraphModule into subgraphs based on operator support
+
+        Returns a GraphModule with submodules for each segment
+        """
+        # Delegate nodes based on operator coverage
+        subgraphs = self.put_nodes_into_subgraphs()
+
+        # Remove segments smaller than the block size (with exceptions)
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+
+        # Set the number of TRT engines to be generated
+        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
+
+        # Tag the accelerated nodes and split the graph accordingly
+        self.tag(subgraphs)
+        return self.split()
+
+    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+        """Generates starter nodes for partitioning + segmentation"""
+        # Starter accelerated nodes are all callable accelerated ops
+        starter_acc_nodes = {
+            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
+        }
+
+        # Started non-accelerated nodes are the rest of the callable nodes
+        starter_non_acc_nodes = {
+            node
+            for node in self.graph_module.graph.nodes
+            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
+        }
+
+        return starter_non_acc_nodes, starter_acc_nodes
+
+
 def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
+    torch_executed_ops: Sequence[Target] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -181,18 +232,21 @@ def partition(
     Returns:
         torch.fx.GraphModule
     """
-    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    # Ensure graph is clean prior to partitioning
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    # Construct
+    supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
     partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)

-    # Determine partitions based on user specifications and operator support
-    # Then, fuse partitions and display overview of supported/unsupported operators
-    partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
+    partitioned_graph = partitioner.partition_graph()

     if verbose:
-        supported_ops.print_support_overview(len(partitions))
+        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)

-    return fused_graph
+    return partitioned_graph


 def get_submod_inputs(
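With the splitter-based flow in place, a partition call cleans the graph, tests operator support, splits the module into "_run_on_acc_*" and "_run_on_gpu_*" children, and reports the accelerated-subgraph count. A rough usage sketch (not from the commit; the toy model is hypothetical, and real support decisions depend on the converter registry of the installed version):

import torch
from torch_tensorrt.dynamo.lowering._partition import partition


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# The Dynamo backend normally supplies an aten-level GraphModule;
# symbolic_trace is used here only to obtain something partitionable
gm = torch.fx.symbolic_trace(Toy())

# Split into _run_on_acc_* (convertible) and _run_on_gpu_* (fallback) children;
# acc blocks smaller than min_block_size are folded back into GPU segments
partitioned = partition(gm, verbose=True, min_block_size=1)

for name, _ in partitioned.named_children():
    print(name)  # e.g. _run_on_acc_0 or _run_on_gpu_0, depending on converter support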
