From a6537cff7e2f4252e9ce3c6296de8c64a16fcd4a Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 27 Dec 2023 18:39:13 -0800
Subject: [PATCH] fix: Repair usage of `torch_executed_ops`

- Previously, `torch_executed_ops` were excluded at partitioning time,
  but not conversion time, causing a bug with obscure usages of `getitem`
- Now, `torch_executed_ops` are excluded at partitioning time and their
  converters are explicitly disabled
---
 py/torch_tensorrt/dynamo/_compiler.py           | 15 ++++++++---
 py/torch_tensorrt/dynamo/_settings.py           |  7 +++---
 .../dynamo/conversion/_ConverterRegistry.py     | 25 +++++++++++++++++++
 .../partitioning/_adjacency_partitioner.py      |  6 +++--
 .../partitioning/_global_partitioner.py         | 25 ++++++++-----------
 tests/py/dynamo/models/test_dyn_models.py       |  5 ++--
 tests/py/dynamo/models/test_export_serde.py     |  5 ++--
 7 files changed, 61 insertions(+), 27 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 5edb7cc32a..9cfd190441 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -2,10 +2,11 @@
 
 import collections.abc
 import logging
-from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 from torch.export import ExportedProgram
+from torch.fx.node import Target
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ -49,6 +50,9 @@
     convert_module,
     repair_long_or_double_inputs,
 )
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
@@ -85,7 +89,7 @@ def compile(
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[List[str]] = None,
+    torch_executed_ops: Optional[Collection[Target]] = None,
     torch_executed_modules: Optional[List[str]] = None,
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
@@ -143,7 +147,7 @@ def compile(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
-        torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
         max_aux_stream (Optional[int]): Maximum streams in the engine
@@ -212,7 +216,7 @@ def compile(
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops
         if torch_executed_ops is not None
-        else [],
+        else set(),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
@@ -256,6 +260,9 @@ def compile_module(
     """
     dryrun_tracker = DryRunTracker()
 
+    # Set torch-executed ops
+    CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
+
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 2992496665..2420a227d8 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -1,8 +1,9 @@
 from dataclasses import dataclass, field
-from typing import Optional, Set, Union
+from typing import Collection, Optional, Union
 
 import torch
 from tensorrt import EngineCapability
+from torch.fx.node import Target
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
@@ -41,7 +42,7 @@ class CompilationSettings:
         debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
-        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops (Collection[Target]): Collection of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
         max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
         version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -75,7 +76,7 @@ class CompilationSettings:
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
-    torch_executed_ops: Set[str] = field(default_factory=set)
+    torch_executed_ops: Collection[Target] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
index d689de3e54..050a62ef3e 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -6,6 +6,7 @@
 from typing import (
     Any,
     Callable,
+    Collection,
     Dict,
     List,
     Optional,
@@ -212,8 +213,16 @@ def __init__(
             CallingConvention.CTX for _ in range(len(self.registries))
         ]
 
+        self.disallowed_targets: Collection[Target] = set()
+
         self.validate_invariants()
 
+    def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
+        self.disallowed_targets = torch_executed_ops
+
+    def get_disallowed_targets(self) -> Collection[Target]:
+        return self.disallowed_targets
+
     def validate_invariants(self) -> None:
"""Validates the invariants required of the dictionaries in the registries @@ -253,6 +262,14 @@ def __getitem_without_validation__( self.validate_invariants() + if ( + key in self.disallowed_targets + or self.qualified_name_or_str(key) in self.disallowed_targets + ): + raise KeyError( + f"A converter exists for {key}, but it was " "explicitly disallowed" + ) + # Iterate over all registries and return the first converter found for registry, calling_convention in zip( self.registries, self.registry_calling_conventions @@ -288,6 +305,14 @@ def __getitem__( self.validate_invariants() key = node.target + if ( + key in self.disallowed_targets + or self.qualified_name_or_str(key) in self.disallowed_targets + ): + raise KeyError( + f"A converter exists for {key}, but it was " "explicitly disallowed" + ) + # Iterate over all registries, validating the converter on the input node # If no capability_validator function is found, assume full coverage for registry, calling_convention in zip( diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 5ec5293474..e263b51bb2 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -42,8 +42,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or node.op == "get_attr" - ) and node_name not in self.torch_executed_ops: + (node in CONVERTERS or node.op == "get_attr") + and node_name not in self.torch_executed_ops + and node.target not in self.torch_executed_ops + ): # If node is a proper, supported computational node, store the operator if not node.is_impure() and node.op != "get_attr": if node_name not in self.supported_operators: diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 4c8efb234e..5982cc95ba 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,8 +1,9 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple import torch from torch.fx.graph_module import GraphModule +from torch.fx.node import Target from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupport, SupportDict from torch_tensorrt.dynamo._defaults import ( @@ -133,16 +134,14 @@ class TorchTensorRTOperatorSupport(OperatorSupport): # type: ignore[misc] def __init__( self, support_dict: Optional[SupportDict] = None, - torch_executed_ops: Optional[Set[str]] = None, + torch_executed_ops: Collection[Target] = set(), ): super().__init__(support_dict) # Initialize sets of supported/unsupported operators self.supported_operators: Dict[str, int] = {} self.unsupported_operators: Dict[str, int] = {} - self.torch_executed_ops: Set[str] = ( - torch_executed_ops if torch_executed_ops is not None else set() - ) + self.torch_executed_ops: Collection[Target] = torch_executed_ops def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node @@ -150,8 +149,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or node.op == "get_attr" - ) and node_name not in self.torch_executed_ops: 
+            (node in CONVERTERS or node.op == "get_attr")
+            and node_name not in self.torch_executed_ops
+            and node.target not in self.torch_executed_ops
+        ):
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
@@ -201,7 +202,7 @@ def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[Set[str]] = None,
+    torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]:
     """Partition an FX GraphModule with aten ops into TRT engines
@@ -211,16 +212,12 @@ def partition(
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule, TorchTensorRTOperatorSupport
     """
-    supported_ops = TorchTensorRTOperatorSupport(
-        torch_executed_ops=torch_executed_ops
-        if torch_executed_ops is not None
-        else set()
-    )
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
     partitioner = TRTPartitioner(
         gm,
         supported_ops,
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index d110845145..ceb4a6dd2c 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,10 @@
 import pytest
 import timm
 import torch
-import torch_tensorrt as torchtrt
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
+import torch_tensorrt as torchtrt
+
 assertions = unittest.TestCase()
@@ -97,7 +98,7 @@ def forward(self, x):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "torch_executed_ops": "torch.ops.aten.abs.default",
+        "torch_executed_ops": {"torch.ops.aten.abs.default"},
         "min_block_size": 1,
     }
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index f5911cb940..ea7700443c 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -3,10 +3,11 @@
 import pytest
 import timm
 import torch
-import torch_tensorrt as torchtrt
 import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
+import torch_tensorrt as torchtrt
+
 assertions = unittest.TestCase()
@@ -206,7 +207,7 @@ def forward(self, x):
         ],
         "ir": ir,
         "min_block_size": 1,
-        "torch_executed_ops": "torch.ops.aten.relu.default",
+        "torch_executed_ops": {"torch.ops.aten.relu.default"},
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
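
Usage sketch: a minimal example of the repaired `torch_executed_ops` behavior,
assuming a torch_tensorrt build containing this change. The toy module, input
shape, and variable names below are hypothetical and only illustrate the API:
ops may be passed either as overload-name strings or as the aten `OpOverload`
objects themselves, and in both forms they are excluded at partitioning time
and their converters are disabled, so they fall back to eager PyTorch.

    import torch
    import torch_tensorrt


    class ToyModule(torch.nn.Module):
        def forward(self, x):
            # Both aten.abs and aten.relu are forced to run in PyTorch below;
            # the remaining ops (e.g. the add) are eligible for TRT engines
            return torch.relu(torch.abs(x)) + 1


    model = ToyModule().eval().cuda()
    inputs = [torch.randn(2, 3, 224, 224).cuda()]

    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        min_block_size=1,
        # String form and OpOverload form are both accepted as Targets
        torch_executed_ops={"torch.ops.aten.abs.default", torch.ops.aten.relu.default},
    )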