@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # Using to_edge_transform_and_lower, we expect SDPA to be preserved and to show up in the delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # Using the plain to_edge flow, we expect SDPA to be decomposed and not to show up in the delegated graph
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
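
For reference, the behavior exercised by the new test can be reproduced outside the test harness. The sketch below is a minimal, hypothetical usage example built only from the APIs that appear in this diff (CoreMLPartitioner, executorch.exir.to_edge_transform_and_lower, format_delegated_graph); the SDPAModule name and tensor shapes are illustrative, not part of the change.

# Minimal sketch of the flow under test; module name and shapes are illustrative.
import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir.backend.utils import format_delegated_graph


class SDPAModule(torch.nn.Module):  # hypothetical example module
    def forward(self, q, k, v):
        return torch.ops.aten.scaled_dot_product_attention.default(q, k, v)


q = k = v = torch.randn(1, 4, 8, 16)
ep = torch.export.export(SDPAModule().eval(), (q, k, v))

# to_edge_transform_and_lower consults the partitioner's ops_to_not_decompose(),
# so SDPA is kept whole and handed to the CoreML delegate as a single op,
# whereas the plain to_edge flow decomposes it before partitioning.
edge = executorch.exir.to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
print(format_delegated_graph(edge.exported_program().graph_module))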