Commit 292f5ce

feat: Improve Dynamo partitioning system
- Upgrade Dynamo partitioning to use a custom version of the Torch _SplitterBase for efficiency and optimized usage in the Dynamo case
- Validate existing use cases are still functional, with the same partitioning schema as before
- Upgrade qualified name checking
- Update testing for new partitioner
- Add new directory to store available partitioners
1 parent 0527edd commit 292f5ce
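At a glance, the pieces changed below fit together roughly as follows. This is a minimal sketch based on the diffs in this commit; `gm` and `sample_inputs` are illustrative placeholders for the traced GraphModule and its sample inputs, and the final engine-conversion step is elided:

import torch
from torch_tensorrt.dynamo.partitioning import partition, get_submod_inputs

# Split the graph into "_run_on_acc_*" (TRT-convertible) and "_run_on_gpu_*" (Torch) submodules
partitioned_module = partition(gm, verbose=True)

for name, _ in partitioned_module.named_children():
    # Only accelerated submodules are converted to TRT engines
    if "_run_on_acc" not in name:
        continue

    submodule = getattr(partitioned_module, name)
    # Capture the sample inputs seen by this submodule via a forward pre-hook
    submodule_inputs = get_submod_inputs(partitioned_module, submodule, sample_inputs)
    # ... convert `submodule` to a TRT engine using `submodule_inputs` ...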

File tree

11 files changed (+355, -47 lines)


py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 1 deletion
@@ -11,7 +11,7 @@
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
     pre_aot_substitutions,
 )
-from torch_tensorrt.dynamo.lowering._partition import (
+from torch_tensorrt.dynamo.partitioning import (
     partition,
     get_submod_inputs,
 )
@@ -131,6 +131,11 @@ def _compile_module(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
+
+        # Criteria for a module to be convertible to TRT
+        if "_run_on_acc" not in name:
+            continue
+
         submodule = getattr(partitioned_module, name)

         # Get submodule inputs
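For reference, a sketch of the submodule naming the splitter produces, which this new check keys off of; the concrete indices are illustrative, but the `_run_on_gpu_` prefix is set by the partitioner added in this commit:

# Hypothetical children of `partitioned_module` after splitting:
#   "_run_on_acc_0"  -> converted to a TRT engine
#   "_run_on_gpu_1"  -> skipped by the check above, runs in Torch
#   "_run_on_acc_2"  -> converted to a TRT engine
for name, _ in partitioned_module.named_children():
    print(name, "->", "TRT" if "_run_on_acc" in name else "Torch")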

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 2 additions & 1 deletion
@@ -305,7 +305,8 @@ def unique_targets(self):
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
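One practical effect of the `@staticmethod` change: `qualified_name_or_str` can now be called on the class itself, without a registry instance, which is how the new partitioner below uses it. A small sketch (the `aten.add.Tensor` target is only an example op, and the exact string form depends on the target type):

import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# Callable targets are converted to a qualified-name string usable as a dict key
name = ConverterRegistry.qualified_name_or_str(torch.ops.aten.add.Tensor)

# String targets (e.g. "output") are returned unchanged, per the branch above
assert ConverterRegistry.qualified_name_or_str("output") == "output"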

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -5,6 +5,5 @@
     SUBSTITUTION_REGISTRY,
     register_substitution,
 )
-from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
 from .substitutions import *
 from ._fusers import *
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from .common import get_submod_inputs
+from ._adjacency_partitioner import (
+    partition,
+)
Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
+import logging
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import torch
+
+from torch.fx.passes.splitter_base import (
+    Subgraph,
+    _SplitterBase,
+    _SplitterSettingBase,
+    FxNetAccNodesFinder,
+    FxNetAccFusionsFinder,
+)
+import torch.fx.passes.operator_support as ops
+from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
+from torch.fx.node import Target
+
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
+from .common import DEFAULT_SINGLE_NODE_PARTITIONS
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+
+from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpSupportTester(ops.OperatorSupportBase):
+    """Class to determine whether operators within a module are supported"""
+
+    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+        super().__init__()
+
+        # Initialize sets of supported/unsupported operators
+        self.supported_operators = {}
+        self.unsupported_operators = {}
+        self.torch_executed_ops = torch_executed_ops
+
+    def is_node_supported(
+        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        node_name = ConverterRegistry.qualified_name_or_str(node.target)
+
+        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+            # If node is a proper, supported computational node, store the operator
+            if not node.is_impure():
+                if node_name not in self.supported_operators:
+                    self.supported_operators[node_name] = 1
+                else:
+                    self.supported_operators[node_name] += 1
+
+            return True
+        else:
+            if not node.is_impure():
+                if node_name not in self.unsupported_operators:
+                    self.unsupported_operators[node_name] = 1
+                else:
+                    self.unsupported_operators[node_name] += 1
+
+            return False
+
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
+        if num_trt_blocks is not None:
+            logger.debug(
+                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
+            )
+
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
+        for node_name, count in self.supported_operators.items():
+            supported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+
+        logger.debug(supported_nodes_str)
+
+        if self.unsupported_operators:
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
+            for node_name, count in self.unsupported_operators.items():
+                unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+
+            logger.debug(unsupported_nodes_str)
+        else:
+            logger.debug("\nAll Nodes Supported\n")
+
+
+class TRTPartitioner(_SplitterBase):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Adapted from, and modified for the Torch-TensorRT Dynamo case:
+    https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871
+
+    Args:
+        module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        operator_support: ops.OperatorSupportBase,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        min_block_size: int = MIN_BLOCK_SIZE,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+
+        self.settings = _SplitterSettingBase(
+            min_acc_module_size=min_block_size,
+            allow_non_tensor=True,
+        )
+        self.operator_support = operator_support
+
+        # Get all accelerated nodes based on operator support conditions
+        self.acc_nodes = FxNetAccNodesFinder(
+            self.module, self.operator_support, self.settings.allow_non_tensor
+        )()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = "_run_on_gpu_"
+        self._node_submodule_map: Dict[str, str] = {}
+
+        self.num_trt_accelerated_subgraphs = None
+        self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+
+    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This pass finds ACC submodules with less than specified size and merges
+        them with adjacent GPU submodules.
+        """
+        result: List[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
+                    ConverterRegistry.qualified_name_or_str(node.target)
+                    in self.allowed_single_node_partition_ops
+                    for node in subgraph.nodes
+                ):
+                    result.append(subgraph)
+                else:
+                    logger.debug(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def partition_graph(self) -> torch.fx.GraphModule:
+        """Partitions the GraphModule into subgraphs based on operator support
+
+        Returns a GraphModule with submodules for each segment
+        """
+        # Delegate nodes based on operator coverage
+        subgraphs = self.put_nodes_into_subgraphs()
+
+        # Remove segments smaller than the block size (with exceptions)
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+
+        # Set the number of TRT engines to be generated
+        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
+
+        # Tag the accelerated nodes and split the graph accordingly
+        self.tag(subgraphs)
+        return self.split()
+
+    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+        """Generates starter nodes for partitioning + segmentation"""
+        # Starter accelerated nodes are all callable accelerated ops
+        starter_acc_nodes = {
+            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
+        }
+
+        # Starter non-accelerated nodes are the rest of the callable nodes
+        starter_non_acc_nodes = {
+            node
+            for node in self.module.graph.nodes
+            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
+        }
+
+        return starter_non_acc_nodes, starter_acc_nodes
+
+
+def partition(
+    gm: torch.fx.GraphModule,
+    verbose: bool = True,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[Target] = set(),
+) -> torch.fx.GraphModule:
+    """Partition an FX GraphModule with aten ops into TRT engines
+    Partitioning is based on converter operator support
+
+    Args:
+        gm: FX GraphModule to partition
+        verbose: Bool representing whether to print operator support
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+    Returns:
+        torch.fx.GraphModule
+    """
+    # Ensure graph is clean prior to partitioning
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    # Construct the operator-support tester and partitioner
+    supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
+    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+
+    partitioned_graph = partitioner.partition_graph()
+
+    if verbose:
+        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
+
+    return partitioned_graph
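A minimal usage sketch of the `partition` entrypoint defined above. The toy module and `symbolic_trace` call are illustrative only; the real Dynamo backend hands in an aten-level graph, so a torch-level trace like this will typically land entirely in a `_run_on_gpu` partition:

import torch
from torch_tensorrt.dynamo.partitioning import partition

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(ToyModel())

# verbose=True prints the supported/unsupported operator overview;
# min_block_size=1 allows single-op accelerated blocks for this toy example
partitioned = partition(gm, verbose=True, min_block_size=1)

print([name for name, _ in partitioned.named_children()])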

py/torch_tensorrt/dynamo/lowering/_partition.py renamed to py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 2 additions & 33 deletions
@@ -1,12 +1,12 @@
 import logging
-from typing import Dict, List, Optional, Sequence, Set
+from typing import Dict, List, Optional, Sequence

 import torch

-from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
 from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
+from .common import DEFAULT_SINGLE_NODE_PARTITIONS
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.operator_support import OperatorSupport

@@ -15,11 +15,6 @@

 logger = logging.getLogger(__name__)

-DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
-    _get_qualified_name(to_replace.new_operator)
-    for to_replace in SUBSTITUTION_REGISTRY.values()
-)
-

 class TRTPartitioner(CapabilityBasedPartitioner):
     """Partitioner to split an FX graph into subgraphs based on operator support
@@ -193,29 +188,3 @@ def partition(
     supported_ops.print_support_overview(len(partitions))

     return fused_graph
-
-
-def get_submod_inputs(
-    mod: torch.fx.GraphModule,
-    submod: torch.fx.GraphModule,
-    inputs: Sequence[torch.Tensor],
-) -> Sequence[torch.Tensor]:
-    """Helper function to get inputs to a Torch submodule
-
-    Args:
-        mod: Parent FX GraphModule
-        submod: Child FX GraphModule
-        inputs: Sample inputs to parent module
-    Returns:
-        Sequence of Tensors representing inputs to child module
-    """
-    acc_inputs = None
-
-    def get_input(self, inputs):
-        nonlocal acc_inputs
-        acc_inputs = inputs
-
-    handle = submod.register_forward_pre_hook(get_input)
-    mod(*inputs)
-    handle.remove()
-    return acc_inputs
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import torch
+import logging
+from typing import Sequence, Set
+from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
+from torch.fx.node import _get_qualified_name
+
+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_submod_inputs(
+    mod: torch.fx.GraphModule,
+    submod: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+) -> Sequence[torch.Tensor]:
+    """Helper function to get inputs to a Torch submodule
+
+    Args:
+        mod: Parent FX GraphModule
+        submod: Child FX GraphModule
+        inputs: Sample inputs to parent module
+    Returns:
+        Sequence of Tensors representing inputs to child module
+    """
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
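A short sketch of how `get_submod_inputs` is intended to be called; `partitioned_module`, the child name, and `sample_inputs` are placeholders standing in for the partitioned graph, one of its `_run_on_acc` children, and the parent's sample inputs:

# Recover the tensors that one child submodule receives during a parent forward pass
submodule = getattr(partitioned_module, "_run_on_acc_0")  # hypothetical child name
submodule_inputs = get_submod_inputs(
    partitioned_module,  # parent GraphModule
    submodule,           # child GraphModule whose inputs we want
    sample_inputs,       # sample inputs to the parent module
)
# Internally, a forward pre-hook on `submodule` records its inputs during mod(*inputs),
# then the hook is removed.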
