
Commit 12b1330

Commit message: init
Parent: 740135f

2 files changed: +61 -1 lines changed

backends/apple/coreml/partition/coreml_partitioner.py (+15 -1)
@@ -3,7 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 import logging
-from typing import List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import coremltools as ct
 
@@ -104,3 +104,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        do_not_decompose = []
+        op_support = OperatorsSupportedForCoreMLBackend()
+        for node in ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and isinstance(node.target, torch._ops.OpOverload)
+                and op_support.is_node_supported(None, node)
+            ):
+                do_not_decompose.append(node.target)
+        return do_not_decompose, None

backends/apple/coreml/test/test_coreml_partitioner.py (+46)
@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # Using to_edge flow, we expect SDPA will be decomposed and not show up in delegated graph
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
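A hedged note on exercising the new test (assuming the executorch Python package, coremltools, and the CoreML backend dependencies are installed; the exact invocation may vary by setup): the runner block at the bottom of the file now calls test_ops_to_not_decompose, so the file can be run directly as a script, or the class can be picked up by a unittest/pytest runner.

python backends/apple/coreml/test/test_coreml_partitioner.py
pytest backends/apple/coreml/test/test_coreml_partitioner.py -v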
