|
6 | 6 | # pyre-unsafe
|
7 | 7 |
|
8 | 8 | import operator
|
9 |
| -from typing import Type |
| 9 | +from typing import final, Optional, Sequence, Type |
10 | 10 |
|
11 | 11 | import torch.fx as fx
|
12 | 12 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
13 | 13 | from executorch.exir.dialects._ops import ops as exir_ops
|
14 |
| -from torch.fx.passes.operator_support import OperatorSupportBase |
| 14 | +from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase |
15 | 15 |
|
16 | 16 |
|
17 |
| -class SupportedTOSAOperatorCheck: |
| 17 | +class SupportedTOSAOperatorCheck(OperatorSupportBase): |
18 | 18 | """
|
19 | 19 | Supported OP for TOSA lowering
|
20 | 20 | """
|
21 | 21 |
|
| 22 | + def __init__(self, tosa_spec: TosaSpecification): |
| 23 | + self.tosa_spec = tosa_spec |
| 24 | + |
22 | 25 | # Should be populated by subclass implementation
|
23 | 26 | tosa_specs: list[TosaSpecification] = []
|
24 | 27 | targets: list[str] = []
|
25 | 28 |
|
26 |
| - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: |
| 29 | + @final |
| 30 | + def is_node_supported(self, submodules, node: fx.Node) -> bool: |
| 31 | + if node.target not in self.targets: |
| 32 | + return False |
| 33 | + return self.is_node_tosa_supported(node, self.tosa_spec) |
| 34 | + |
| 35 | + def is_node_tosa_supported( |
| 36 | + self, node: fx.Node, tosa_spec: TosaSpecification |
| 37 | + ) -> bool: |
27 | 38 | """
|
28 | 39 | Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
|
29 |
| - To be implemented by subclasses targeting |
30 | 40 | """
|
31 |
| - raise NotImplementedError("NodeVisitor must be extended.") |
| 41 | + raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.") |
32 | 42 |
|
33 | 43 |
|
34 | 44 | # container for all SupportedTosaOperatorCheck classes
|
35 |
| -_tosa_spec_dicts: dict[ |
36 |
| - TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]] |
37 |
| -] = { |
38 |
| - TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, |
39 |
| - TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, |
| 45 | +_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { |
| 46 | + TosaSpecification.create_from_string("TOSA-0.80+BI"): [], |
| 47 | + TosaSpecification.create_from_string("TOSA-0.80+MI"): [], |
40 | 48 | }
|
41 | 49 |
|
42 | 50 |
|
43 |
| -def register_tosa_support_check(checker): |
| 51 | +def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): |
44 | 52 | """
|
45 | 53 | Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
|
46 | 54 | to be registered for checking if a torch.fx.Node is lowerable given
|
47 | 55 | a TOSA specification.
|
48 | 56 | """
|
49 | 57 | for tosa_spec in checker.tosa_specs:
|
50 |
| - for target in checker.targets: |
51 |
| - _tosa_spec_dicts[tosa_spec][target] = checker |
| 58 | + _tosa_spec_support[tosa_spec].append(checker) |
52 | 59 | return checker
|
53 | 60 |
|
54 | 61 |
|
55 | 62 | def get_registered_tosa_support_checks(
|
56 | 63 | tosa_spec: TosaSpecification,
|
57 |
| -) -> dict[str, SupportedTOSAOperatorCheck]: |
| 64 | +) -> list[Type[SupportedTOSAOperatorCheck]]: |
58 | 65 |
|
59 |
| - if tosa_spec not in _tosa_spec_dicts: |
| 66 | + if tosa_spec not in _tosa_spec_support: |
60 | 67 | raise RuntimeError
|
61 | 68 |
|
62 |
| - tosa_support_checks = {} |
63 |
| - for target, tosa_check in _tosa_spec_dicts[tosa_spec].items(): |
64 |
| - tosa_support_checks[target] = tosa_check() |
65 |
| - |
66 |
| - return tosa_support_checks |
| 69 | + return _tosa_spec_support[tosa_spec] |
67 | 70 |
|
68 | 71 |
|
69 |
| -class TOSASupportedOperators(OperatorSupportBase): |
70 |
| - def __init__(self, tosa_spec: TosaSpecification): |
71 |
| - super().__init__() |
72 |
| - self.tosa_spec = tosa_spec |
| 72 | +def tosa_support_factory( |
| 73 | + tosa_spec: TosaSpecification, |
| 74 | + additional_checks: Optional[Sequence[OperatorSupportBase]] = None, |
| 75 | +) -> OperatorSupportBase: |
| 76 | + return chain( |
| 77 | + any_chain( |
| 78 | + BaseTOSASupportList(), |
| 79 | + *( |
| 80 | + check(tosa_spec) |
| 81 | + for check in get_registered_tosa_support_checks(tosa_spec) |
| 82 | + ), |
| 83 | + ), |
| 84 | + *additional_checks if additional_checks else [], |
| 85 | + ) |
| 86 | + |
| 87 | + |
| 88 | +class BaseTOSASupportList(OperatorSupportBase): |
73 | 89 |
|
74 | 90 | def is_node_supported(self, submodules, node: fx.Node) -> bool:
|
75 | 91 | supported = node.op == "call_function" and node.target in [
|
@@ -123,18 +139,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
|
123 | 139 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
|
124 | 140 | ]
|
125 | 141 |
|
126 |
| - if not supported: |
127 |
| - supported = self.is_node_supported_custom(node) |
128 |
| - |
129 |
| - # Override partitioning based on pre partition passes |
130 |
| - if "arm_override_partition" in node.meta: |
131 |
| - supported = supported & node.meta["arm_override_partition"] |
132 |
| - node.meta.pop("arm_override_partition") |
133 |
| - |
134 | 142 | return supported
|
135 |
| - |
136 |
| - def is_node_supported_custom(self, node: fx.Node) -> bool: |
137 |
| - tosa_checks = get_registered_tosa_support_checks(self.tosa_spec) |
138 |
| - if node.target in tosa_checks.keys(): |
139 |
| - return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index] |
140 |
| - return False |
|
0 commit comments