import logging
from typing import Dict, List, Optional, Sequence, Tuple

import torch

from torch.fx.passes.splitter_base import (
    Subgraph,
    _SplitterBase,
    _SplitterSettingBase,
    FxNetAccNodesFinder,
    FxNetAccFusionsFinder,
)
import torch.fx.passes.operator_support as ops
from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
from torch.fx.node import Target

from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
from .common import DEFAULT_SINGLE_NODE_PARTITIONS
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE

from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS


logger = logging.getLogger(__name__)


class OpSupportTester(ops.OperatorSupportBase):
    """Class to determine whether operators within a module are supported"""

    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
        super().__init__()

        # Initialize dictionaries of supported/unsupported operator counts
        self.supported_operators: Dict[str, int] = {}
        self.unsupported_operators: Dict[str, int] = {}
        self.torch_executed_ops = torch_executed_ops

    def is_node_supported(
        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        node_name = ConverterRegistry.qualified_name_or_str(node.target)

        if node in CONVERTERS and node_name not in self.torch_executed_ops:
            # If node is a proper, supported computational node, store the operator
            if not node.is_impure():
                if node_name not in self.supported_operators:
                    self.supported_operators[node_name] = 1
                else:
                    self.supported_operators[node_name] += 1

            return True
        else:
            if not node.is_impure():
                if node_name not in self.unsupported_operators:
                    self.unsupported_operators[node_name] = 1
                else:
                    self.unsupported_operators[node_name] += 1

            return False

    def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None:
        if num_trt_blocks is not None:
            logger.debug(
                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
            )

        # Reformat support messages so the node overview prints as a single debug string
        supported_nodes_str = "\nSupported Nodes:\n"
        for node_name, count in self.supported_operators.items():
            supported_nodes_str += f"- {node_name} + Operator Count: {count}\n"

        logger.debug(supported_nodes_str)

        if self.unsupported_operators:
            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
            for node_name, count in self.unsupported_operators.items():
                unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n"

            logger.debug(unsupported_nodes_str)
        else:
            logger.debug("\nAll Nodes Supported\n")


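# A minimal sketch of exercising OpSupportTester in isolation, assuming an
# aten-level FX GraphModule `gm` is already available; the excluded op name is
# purely illustrative. Useful for auditing converter coverage before partitioning.
#
#   tester = OpSupportTester(torch_executed_ops={"torch.ops.aten.add.Tensor"})
#   submodules = dict(gm.named_modules())
#   for node in gm.graph.nodes:
#       if node.op in CALLABLE_NODE_OPS:
#           tester.is_node_supported(submodules, node)
#   tester.print_support_overview()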
class TRTPartitioner(_SplitterBase):
    """Partitioner to split an FX graph into subgraphs based on operator support

    Adapted from, and modified for the Torch-TensorRT Dynamo case:
    https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871

    Args:
        module: FX GraphModule to partition
        operator_support: OperatorSupport class describing allowed operators
        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
            Generally useful for module-level exclusion ops which are intensive despite being single functions
        min_block_size: Minimum number of computational operators per block
    Returns:
        torch.fx.GraphModule
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: ops.OperatorSupportBase,
        allowed_single_node_partition_ops: Optional[
            Sequence[str]
        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
        min_block_size: int = MIN_BLOCK_SIZE,
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module

        self.settings = _SplitterSettingBase(
            min_acc_module_size=min_block_size,
            allow_non_tensor=True,
        )
        self.operator_support = operator_support

        # Get all accelerated nodes based on operator support conditions
        self.acc_nodes = FxNetAccNodesFinder(
            self.module, self.operator_support, self.settings.allow_non_tensor
        )()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = "_run_on_gpu_"
        self._node_submodule_map: Dict[str, str] = {}

        self.num_trt_accelerated_subgraphs = None
        self.allowed_single_node_partition_ops = allowed_single_node_partition_ops

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules with less than specified size and merges
        them with adjacent GPU submodules.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
                    ConverterRegistry.qualified_name_or_str(node.target)
                    in self.allowed_single_node_partition_ops
                    for node in subgraph.nodes
                ):
                    result.append(subgraph)
                else:
                    logger.debug(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result
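    # Behavior sketch for remove_small_acc_subgraphs, with node counts assumed
    # purely for illustration: given min_acc_module_size=5,
    #
    #   subgraphs = [
    #       Subgraph(is_acc=False, nodes=gpu_nodes),    # Torch/GPU segment
    #       Subgraph(is_acc=True, nodes=[conv, relu]),  # only 2 acc nodes
    #   ]
    #   merged = partitioner.remove_small_acc_subgraphs(subgraphs)
    #
    # folds the undersized acc segment into the preceding GPU segment, so no
    # tiny TRT engine is emitted. If one of its ops appears in
    # allowed_single_node_partition_ops, the segment is instead kept as-is.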

    def partition_graph(self) -> torch.fx.GraphModule:
        """Partitions the GraphModule into subgraphs based on operator support

        Returns a GraphModule with submodules for each segment
        """
        # Delegate nodes based on operator coverage
        subgraphs = self.put_nodes_into_subgraphs()

        # Remove segments smaller than the block size (with exceptions)
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)

        # Set the number of TRT engines to be generated
        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])

        # Tag the accelerated nodes and split the graph accordingly
        self.tag(subgraphs)
        return self.split()

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """Generates starter nodes for partitioning + segmentation"""
        # Starter accelerated nodes are all callable accelerated ops
        starter_acc_nodes = {
            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
        }

        # Starter non-accelerated nodes are the rest of the callable nodes
        starter_non_acc_nodes = {
            node
            for node in self.module.graph.nodes
            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
        }

        return starter_non_acc_nodes, starter_acc_nodes

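# A minimal sketch of driving TRTPartitioner directly, assuming `gm` is an
# aten-level FX GraphModule (e.g. produced by a Dynamo/torch.export trace);
# the min_block_size value is illustrative:
#
#   supported_ops = OpSupportTester()
#   partitioner = TRTPartitioner(gm, supported_ops, min_block_size=5)
#   partitioned_gm = partitioner.partition_graph()
#   print(partitioner.num_trt_accelerated_subgraphs)
#
# Non-accelerated segments land in submodules prefixed with "_run_on_gpu_"
# (see non_acc_submodule_name above), while accelerated segments become the
# TRT-eligible submodules counted by num_trt_accelerated_subgraphs.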

def partition(
    gm: torch.fx.GraphModule,
    verbose: bool = True,
    min_block_size: int = MIN_BLOCK_SIZE,
    torch_executed_ops: Sequence[Target] = set(),
) -> torch.fx.GraphModule:
    """Partition an FX GraphModule with aten ops into TRT engines
    Partitioning is based on converter operator support

    Args:
        gm: FX GraphModule to partition
        verbose: Bool representing whether to print operator support
        min_block_size: Minimum number of operators per TRT-Engine Block
        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
    Returns:
        torch.fx.GraphModule
    """
    # Ensure graph is clean prior to partitioning
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    # Construct the operator support checker and graph partitioner
    supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)

    partitioned_graph = partitioner.partition_graph()

    if verbose:
        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)

    return partitioned_graph
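

# A minimal end-to-end sketch, assuming a user-defined `MyModel` and an export
# step that yields an aten-level GraphModule (the model, input shape, and
# excluded op are illustrative assumptions):
#
#   model = MyModel().eval().cuda()
#   inputs = (torch.randn(1, 3, 224, 224, device="cuda"),)
#   gm = torch.export.export(model, inputs).module()
#   partitioned_gm = partition(
#       gm,
#       verbose=True,
#       min_block_size=5,
#       torch_executed_ops={"torch.ops.aten.sub.Tensor"},
#   )
#
# Each child submodule of the result is either a TRT-eligible segment or a
# Torch-fallback segment, and the verbose flag logs the per-operator support
# overview gathered by OpSupportTester.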