Skip to content

Commit 6d49702

Browse files
authored
Arm backend: Add additional checks for operator support. (#8367)
* Refactor TOSA support to use chain flow This seems to match the intention of OperatorSupportBase and increases the flexibility of our support flow. New checks if different kinds are simply added to `tosa_support_factory` * Add additional checks for operator support. This can be used to avoid partitioning parts of a model when debugging. Though any OperatorSupportBase can be used, we add three OperatorSupport as utilities: DontPartition: Don't partition based on node target DontPartitionName: Don't partition based on node name DontPartitionModule: Don't partition based on which module the op comes from. All these checks can match parts of the target name, and save a list of the nodes they reject for debugging. Signed-off-by: Erik Lundell <[email protected]>
1 parent 93c3b66 commit 6d49702

File tree

10 files changed

+405
-51
lines changed

10 files changed

+405
-51
lines changed

backends/arm/arm_partitioner.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
import logging
99
import os
10-
from typing import Callable, final, List, Optional, Tuple
10+
from typing import Callable, final, List, Optional, Sequence, Tuple
1111

1212
import torch
1313
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
1414
ArmBackend,
1515
) # usort: skip
1616
from executorch.backends.arm.operator_support.tosa_supported_operators import (
17-
TOSASupportedOperators,
17+
tosa_support_factory,
1818
)
1919
from executorch.backends.arm.tosa_specification import TosaSpecification
2020
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -27,6 +27,8 @@
2727
from executorch.exir.dialects._ops import ops as exir_ops
2828
from torch.export.exported_program import ExportedProgram
2929
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
30+
from torch.fx.passes.operator_support import OperatorSupportBase
31+
3032

3133
logger = logging.getLogger(__name__)
3234
logger.setLevel(logging.WARNING)
@@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool:
5456

5557
@final
5658
class ArmPartitioner(Partitioner):
57-
def __init__(self, compile_spec: List[CompileSpec]) -> None:
59+
def __init__(
60+
self,
61+
compile_spec: List[CompileSpec],
62+
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
63+
) -> None:
5864
self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)
65+
self.additional_checks = additional_checks
5966

6067
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
6168
# Run the CapabilityBasedPartitioner to return the largest possible
@@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7279

7380
capability_partitioner = CapabilityBasedPartitioner(
7481
exported_program.graph_module,
75-
TOSASupportedOperators(tosa_spec),
82+
tosa_support_factory(tosa_spec, self.additional_checks),
7683
allows_single_node_partition=True,
7784
)
7885
partition_list = capability_partitioner.propose_partitions()

backends/arm/operator_support/convolution_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2525
]
2626

27-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
27+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2828

2929
# Not implemented
3030
transposed = cast(bool, node.args[6])

backends/arm/operator_support/pool_2d_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4343
TosaSpecification.create_from_string("TOSA-0.80+MI"),
4444
]
4545

46-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
46+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4747
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
4848
return True
4949

@@ -73,7 +73,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
7373
TosaSpecification.create_from_string("TOSA-0.80+MI"),
7474
]
7575

76-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
76+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
7777
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
7878
return True
7979

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
2323
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2424
]
2525

26-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
26+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2727
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
2828
return True
2929

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -29,7 +29,7 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
2929
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3030
]
3131

32-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
32+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
3333

3434
# TODO MLETORCH-525 Remove warning
3535
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:

backends/arm/operator_support/to_copy_support.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def _merge_supported_types(
7070
)
7171
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
7272

73-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
73+
def is_node_tosa_supported(
74+
self, node: fx.Node, tosa_spec: TosaSpecification
75+
) -> bool:
7476
assert node.target in self.targets
7577

7678
if tosa_spec not in self.tosa_specs:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,70 +6,86 @@
66
# pyre-unsafe
77

88
import operator
9-
from typing import Type
9+
from typing import final, Optional, Sequence, Type
1010

1111
import torch.fx as fx
1212
from executorch.backends.arm.tosa_specification import TosaSpecification
1313
from executorch.exir.dialects._ops import ops as exir_ops
14-
from torch.fx.passes.operator_support import OperatorSupportBase
14+
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
1515

1616

17-
class SupportedTOSAOperatorCheck:
17+
class SupportedTOSAOperatorCheck(OperatorSupportBase):
1818
"""
1919
Supported OP for TOSA lowering
2020
"""
2121

22+
def __init__(self, tosa_spec: TosaSpecification):
23+
self.tosa_spec = tosa_spec
24+
2225
# Should be populated by subclass implementation
2326
tosa_specs: list[TosaSpecification] = []
2427
targets: list[str] = []
2528

26-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
29+
@final
30+
def is_node_supported(self, submodules, node: fx.Node) -> bool:
31+
if node.target not in self.targets:
32+
return False
33+
return self.is_node_tosa_supported(node, self.tosa_spec)
34+
35+
def is_node_tosa_supported(
36+
self, node: fx.Node, tosa_spec: TosaSpecification
37+
) -> bool:
2738
"""
2839
Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
29-
To be implemented by subclasses targeting
3040
"""
31-
raise NotImplementedError("NodeVisitor must be extended.")
41+
raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")
3242

3343

3444
# container for all SupportedTosaOperatorCheck classes
35-
_tosa_spec_dicts: dict[
36-
TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
37-
] = {
38-
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
39-
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
45+
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
46+
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
47+
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
4048
}
4149

4250

43-
def register_tosa_support_check(checker):
51+
def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
4452
"""
4553
Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
4654
to be registered for checking if a torch.fx.Node is lowerable given
4755
a TOSA specification.
4856
"""
4957
for tosa_spec in checker.tosa_specs:
50-
for target in checker.targets:
51-
_tosa_spec_dicts[tosa_spec][target] = checker
58+
_tosa_spec_support[tosa_spec].append(checker)
5259
return checker
5360

5461

5562
def get_registered_tosa_support_checks(
5663
tosa_spec: TosaSpecification,
57-
) -> dict[str, SupportedTOSAOperatorCheck]:
64+
) -> list[Type[SupportedTOSAOperatorCheck]]:
5865

59-
if tosa_spec not in _tosa_spec_dicts:
66+
if tosa_spec not in _tosa_spec_support:
6067
raise RuntimeError
6168

62-
tosa_support_checks = {}
63-
for target, tosa_check in _tosa_spec_dicts[tosa_spec].items():
64-
tosa_support_checks[target] = tosa_check()
65-
66-
return tosa_support_checks
69+
return _tosa_spec_support[tosa_spec]
6770

6871

69-
class TOSASupportedOperators(OperatorSupportBase):
70-
def __init__(self, tosa_spec: TosaSpecification):
71-
super().__init__()
72-
self.tosa_spec = tosa_spec
72+
def tosa_support_factory(
73+
tosa_spec: TosaSpecification,
74+
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
75+
) -> OperatorSupportBase:
76+
return chain(
77+
any_chain(
78+
BaseTOSASupportList(),
79+
*(
80+
check(tosa_spec)
81+
for check in get_registered_tosa_support_checks(tosa_spec)
82+
),
83+
),
84+
*additional_checks if additional_checks else [],
85+
)
86+
87+
88+
class BaseTOSASupportList(OperatorSupportBase):
7389

7490
def is_node_supported(self, submodules, node: fx.Node) -> bool:
7591
supported = node.op == "call_function" and node.target in [
@@ -123,18 +139,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
123139
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
124140
]
125141

126-
if not supported:
127-
supported = self.is_node_supported_custom(node)
128-
129-
# Override partitioning based on pre partition passes
130-
if "arm_override_partition" in node.meta:
131-
supported = supported & node.meta["arm_override_partition"]
132-
node.meta.pop("arm_override_partition")
133-
134142
return supported
135-
136-
def is_node_supported_custom(self, node: fx.Node) -> bool:
137-
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
138-
if node.target in tosa_checks.keys():
139-
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
140-
return False

0 commit comments

Comments
 (0)