Skip to content

Commit 705f100

Browse files
authored
Arm backend: Bump Vela pin to support matmul on Ethos-U55 (#9184)
- Remove xfails - Refactor test_mm to use testing pipelines Signed-off-by: Erik Lundell <[email protected]>
1 parent 9216768 commit 705f100

File tree

3 files changed

+66
-160
lines changed

3 files changed

+66
-160
lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,9 @@ def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]
150150
test_data = test_data_generator()
151151
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
152152

153-
# Expected to fail on FVP as TOSA.MATMUL is not supported on U55
154153
@parameterized.expand(BMM.test_data_generators)
155154
@pytest.mark.corstone_fvp
156-
@conftest.expectedFailureOnFVP
157-
def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]):
155+
def test_bmm_u55_BI(self, test_data_generator: Callable[[], Tuple]):
158156
test_data = test_data_generator()
159157
self._test_bmm_ethosu_BI_pipeline(
160158
self.BMM(), common.get_u55_compile_spec(), test_data
@@ -171,10 +169,7 @@ def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
171169
# Expected to fail on FVP as TOSA.MATMUL is not supported on U55
172170
@parameterized.expand(BMMSingleInput.test_data_generators)
173171
@pytest.mark.corstone_fvp
174-
@conftest.expectedFailureOnFVP
175-
def test_bmm_single_input_u55_BI_xfails(
176-
self, test_data_generator: Callable[[], Tuple]
177-
):
172+
def test_bmm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]):
178173
test_data = test_data_generator()
179174
self._test_bmm_ethosu_BI_pipeline(
180175
self.BMMSingleInput(), common.get_u55_compile_spec(), test_data

backends/arm/test/ops/test_mm.py

Lines changed: 63 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -4,160 +4,71 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
8-
import unittest
7+
from typing import Callable
98

10-
from typing import Callable, Tuple
11-
12-
import pytest
139
import torch
1410
from executorch.backends.arm.test import common
15-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16-
from executorch.exir.backend.backend_details import CompileSpec
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
1717
from parameterized import parameterized
1818

19-
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.INFO)
21-
22-
23-
class TestMM(unittest.TestCase):
24-
"""Tests MatMul"""
25-
26-
class MM(torch.nn.Module):
27-
test_data_generators = [
28-
lambda: (torch.rand(3, 5), torch.rand(5, 2)),
29-
lambda: (torch.rand(1, 1), torch.rand(1, 1)),
30-
lambda: (torch.ones(55, 3), torch.ones(3, 44)),
31-
lambda: (10000 * torch.randn(1, 10), torch.randn(10, 5)),
32-
lambda: (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
33-
]
34-
35-
def forward(self, x, y):
36-
return torch.mm(x, y)
37-
38-
class MMSingleInput(torch.nn.Module):
39-
test_data_generators = [
40-
lambda: (torch.rand(3, 3),),
41-
lambda: (torch.ones(128, 128),),
42-
lambda: (10000 * torch.randn(25, 25),),
43-
lambda: (5 + 5 * torch.randn(64, 64),),
44-
]
45-
46-
def forward(self, x):
47-
return torch.mm(x, x)
48-
49-
def _test_mm_tosa_MI_pipeline(
50-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
51-
):
52-
(
53-
ArmTester(
54-
module,
55-
example_inputs=test_data,
56-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
57-
)
58-
.export()
59-
.check_count({"torch.ops.aten.mm.default": 1})
60-
.check_not(["torch.ops.quantized_decomposed"])
61-
.to_edge()
62-
.partition()
63-
.check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
64-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
65-
.to_executorch()
66-
.run_method_and_compare_outputs(inputs=test_data)
67-
)
68-
69-
def _test_mm_tosa_BI_pipeline(
70-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
71-
):
72-
(
73-
ArmTester(
74-
module,
75-
example_inputs=test_data,
76-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
77-
)
78-
.quantize()
79-
.export()
80-
.check_count({"torch.ops.aten.mm.default": 1})
81-
.check(["torch.ops.quantized_decomposed"])
82-
.to_edge()
83-
.partition()
84-
.check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
85-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
86-
.to_executorch()
87-
.run_method_and_compare_outputs(inputs=test_data)
88-
)
89-
90-
def _test_mm_ethosu_BI_pipeline(
91-
self,
92-
compile_spec: CompileSpec,
93-
module: torch.nn.Module,
94-
test_data: Tuple[torch.Tensor],
95-
):
96-
(
97-
ArmTester(
98-
module,
99-
example_inputs=test_data,
100-
compile_spec=compile_spec,
101-
)
102-
.quantize()
103-
.export()
104-
.check_count({"torch.ops.aten.mm.default": 1})
105-
.check(["torch.ops.quantized_decomposed"])
106-
.to_edge()
107-
.partition()
108-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
109-
.to_executorch()
110-
)
111-
112-
@parameterized.expand(MM.test_data_generators)
113-
def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
114-
test_data = test_data_generator()
115-
self._test_mm_tosa_MI_pipeline(self.MM(), test_data)
116-
117-
@parameterized.expand(MMSingleInput.test_data_generators)
118-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
119-
def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
120-
test_data = test_data_generator()
121-
self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)
122-
123-
@parameterized.expand(MM.test_data_generators)
124-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
125-
def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
126-
test_data = test_data_generator()
127-
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)
128-
129-
@parameterized.expand(MMSingleInput.test_data_generators)
130-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
131-
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
132-
test_data = test_data_generator()
133-
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)
134-
135-
# TODO: Enable numerical testing
136-
@parameterized.expand(MM.test_data_generators)
137-
def test_mm_u55_BI(self, test_data_generator: Callable[[], Tuple]):
138-
test_data = test_data_generator()
139-
self._test_mm_ethosu_BI_pipeline(
140-
common.get_u55_compile_spec(), self.MM(), test_data
141-
)
142-
143-
# TODO: Enable numerical testing
144-
@parameterized.expand(MMSingleInput.test_data_generators)
145-
def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]):
146-
test_data = test_data_generator()
147-
self._test_mm_ethosu_BI_pipeline(
148-
common.get_u55_compile_spec(), self.MMSingleInput(), test_data
149-
)
150-
151-
@parameterized.expand(MM.test_data_generators)
152-
def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
153-
test_data = test_data_generator()
154-
self._test_mm_ethosu_BI_pipeline(
155-
common.get_u85_compile_spec(), self.MM(), test_data
156-
)
157-
158-
@parameterized.expand(MMSingleInput.test_data_generators)
159-
def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
160-
test_data = test_data_generator()
161-
self._test_mm_ethosu_BI_pipeline(
162-
common.get_u85_compile_spec(), self.MMSingleInput(), test_data
163-
)
19+
test_t = tuple[torch.Tensor, torch.Tensor]
20+
21+
22+
class MM(torch.nn.Module):
23+
test_data_generators = [
24+
lambda: (torch.rand(3, 5), torch.rand(5, 2)),
25+
lambda: (torch.rand(1, 1), torch.rand(1, 1)),
26+
lambda: (torch.ones(55, 3), torch.ones(3, 44)),
27+
lambda: (10000 * torch.randn(1, 10), torch.randn(10, 5)),
28+
lambda: (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
29+
]
30+
aten_op = "torch.ops.aten.mm.default"
31+
exir_op = "executorch_exir_dialects_edge__ops_aten_mm_default"
32+
33+
def forward(self, x, y):
34+
return torch.mm(x, y)
35+
36+
37+
@parameterized.expand(MM.test_data_generators)
38+
def test_mm_tosa_MI(test_data_generator: Callable[[], tuple]):
39+
test_data = test_data_generator()
40+
TosaPipelineMI[test_t](MM(), test_data, MM.aten_op).run()
41+
42+
43+
@parameterized.expand(MM.test_data_generators)
44+
def test_mm_tosa_BI(test_data_generator: Callable[[], tuple]):
45+
test_data = test_data_generator()
46+
TosaPipelineBI[test_t](MM(), test_data, MM.aten_op, MM.exir_op).run()
47+
48+
49+
@parameterized.expand(MM.test_data_generators)
50+
def test_mm_tosa_u55(test_data_generator: Callable[[], tuple]):
51+
test_data = test_data_generator()
52+
EthosU55PipelineBI[test_t](MM(), test_data, MM.aten_op).run()
53+
54+
55+
@parameterized.expand(MM.test_data_generators)
56+
def test_mm_tosa_u85(test_data_generator: Callable[[], tuple]):
57+
test_data = test_data_generator()
58+
EthosU85PipelineBI[test_t](MM(), test_data, MM.aten_op, MM.exir_op).run()
59+
60+
61+
@parameterized.expand(MM.test_data_generators)
62+
@common.SkipIfNoCorstone300
63+
def test_mm_tosa_u55_on_fvp(test_data_generator: Callable[[], tuple]):
64+
test_data = test_data_generator()
65+
EthosU55PipelineBI[test_t](MM(), test_data, MM.aten_op, run_on_fvp=True).run()
66+
67+
68+
@parameterized.expand(MM.test_data_generators)
69+
@common.SkipIfNoCorstone320
70+
def test_mm_tosa_u85_on_fvp(test_data_generator: Callable[[], tuple]):
71+
test_data = test_data_generator()
72+
EthosU85PipelineBI[test_t](
73+
MM(), test_data, MM.aten_op, MM.exir_op, run_on_fvp=True
74+
).run()

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"
6565

6666
# vela
6767
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
68-
vela_rev="46d88f56902be0706e051c10153ffb7620e01ee3"
68+
vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
6969

7070
########
7171
### Optional user args

0 commit comments

Comments
 (0)