Skip to content

Add additional checks for operator support. #8367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import logging
import os
from typing import Callable, final, List, Optional, Tuple
from typing import Callable, final, List, Optional, Sequence, Tuple

import torch
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
ArmBackend,
) # usort: skip
from executorch.backends.arm.operator_support.tosa_supported_operators import (
TOSASupportedOperators,
tosa_support_factory,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
Expand All @@ -27,6 +27,8 @@
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool:

@final
class ArmPartitioner(Partitioner):
def __init__(self, compile_spec: List[CompileSpec]) -> None:
def __init__(
self,
compile_spec: List[CompileSpec],
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
) -> None:
self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)
self.additional_checks = additional_checks

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
Expand All @@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
TOSASupportedOperators(tosa_spec),
tosa_support_factory(tosa_spec, self.additional_checks),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# Not implemented
transposed = cast(bool, node.args[6])
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/pool_2d_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

Expand Down Expand Up @@ -73,7 +73,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operator_support/reduce_sum_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -29,7 +29,7 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# TODO MLETORCH-525 Remove warning
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def _merge_supported_types(
)
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
assert node.target in self.targets

if tosa_spec not in self.tosa_specs:
Expand Down
80 changes: 41 additions & 39 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,70 +6,86 @@
# pyre-unsafe

import operator
from typing import Type
from typing import final, Optional, Sequence, Type

import torch.fx as fx
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase


class SupportedTOSAOperatorCheck:
class SupportedTOSAOperatorCheck(OperatorSupportBase):
"""
Supported OP for TOSA lowering
"""

def __init__(self, tosa_spec: TosaSpecification):
self.tosa_spec = tosa_spec

# Should be populated by subclass implementation
tosa_specs: list[TosaSpecification] = []
targets: list[str] = []

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
@final
def is_node_supported(self, submodules, node: fx.Node) -> bool:
if node.target not in self.targets:
return False
return self.is_node_tosa_supported(node, self.tosa_spec)

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
"""
Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
To be implemented by subclasses targeting
"""
raise NotImplementedError("NodeVisitor must be extended.")
raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")


# container for all SupportedTosaOperatorCheck classes
_tosa_spec_dicts: dict[
TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
}


def register_tosa_support_check(checker):
def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
"""
Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
to be registered for checking if a torch.fx.Node is lowerable given
a TOSA specification.
"""
for tosa_spec in checker.tosa_specs:
for target in checker.targets:
_tosa_spec_dicts[tosa_spec][target] = checker
_tosa_spec_support[tosa_spec].append(checker)
return checker


def get_registered_tosa_support_checks(
tosa_spec: TosaSpecification,
) -> dict[str, SupportedTOSAOperatorCheck]:
) -> list[Type[SupportedTOSAOperatorCheck]]:

if tosa_spec not in _tosa_spec_dicts:
if tosa_spec not in _tosa_spec_support:
raise RuntimeError

tosa_support_checks = {}
for target, tosa_check in _tosa_spec_dicts[tosa_spec].items():
tosa_support_checks[target] = tosa_check()

return tosa_support_checks
return _tosa_spec_support[tosa_spec]


class TOSASupportedOperators(OperatorSupportBase):
def __init__(self, tosa_spec: TosaSpecification):
super().__init__()
self.tosa_spec = tosa_spec
def tosa_support_factory(
tosa_spec: TosaSpecification,
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
) -> OperatorSupportBase:
return chain(
any_chain(
BaseTOSASupportList(),
*(
check(tosa_spec)
for check in get_registered_tosa_support_checks(tosa_spec)
),
),
*additional_checks if additional_checks else [],
)


class BaseTOSASupportList(OperatorSupportBase):

def is_node_supported(self, submodules, node: fx.Node) -> bool:
supported = node.op == "call_function" and node.target in [
Expand Down Expand Up @@ -123,18 +139,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
]

if not supported:
supported = self.is_node_supported_custom(node)

# Override partitioning based on pre partition passes
if "arm_override_partition" in node.meta:
supported = supported & node.meta["arm_override_partition"]
node.meta.pop("arm_override_partition")

return supported

def is_node_supported_custom(self, node: fx.Node) -> bool:
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
if node.target in tosa_checks.keys():
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
return False
Loading
Loading