|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import logging |
8 | | -import unittest |
| 7 | +from typing import Callable |
9 | 8 |
|
10 | | -from typing import Callable, Tuple |
11 | | - |
12 | | -import pytest |
13 | 9 | import torch |
14 | 10 | from executorch.backends.arm.test import common |
15 | | -from executorch.backends.arm.test.tester.arm_tester import ArmTester |
16 | | -from executorch.exir.backend.backend_details import CompileSpec |
| 11 | +from executorch.backends.arm.test.tester.test_pipeline import ( |
| 12 | + EthosU55PipelineBI, |
| 13 | + EthosU85PipelineBI, |
| 14 | + TosaPipelineBI, |
| 15 | + TosaPipelineMI, |
| 16 | +) |
17 | 17 | from parameterized import parameterized |
18 | 18 |
|
19 | | -logger = logging.getLogger(__name__) |
20 | | -logger.setLevel(logging.INFO) |
21 | | - |
22 | | - |
23 | | -class TestMM(unittest.TestCase): |
24 | | - """Tests MatMul""" |
25 | | - |
26 | | - class MM(torch.nn.Module): |
27 | | - test_data_generators = [ |
28 | | - lambda: (torch.rand(3, 5), torch.rand(5, 2)), |
29 | | - lambda: (torch.rand(1, 1), torch.rand(1, 1)), |
30 | | - lambda: (torch.ones(55, 3), torch.ones(3, 44)), |
31 | | - lambda: (10000 * torch.randn(1, 10), torch.randn(10, 5)), |
32 | | - lambda: (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)), |
33 | | - ] |
34 | | - |
35 | | - def forward(self, x, y): |
36 | | - return torch.mm(x, y) |
37 | | - |
38 | | - class MMSingleInput(torch.nn.Module): |
39 | | - test_data_generators = [ |
40 | | - lambda: (torch.rand(3, 3),), |
41 | | - lambda: (torch.ones(128, 128),), |
42 | | - lambda: (10000 * torch.randn(25, 25),), |
43 | | - lambda: (5 + 5 * torch.randn(64, 64),), |
44 | | - ] |
45 | | - |
46 | | - def forward(self, x): |
47 | | - return torch.mm(x, x) |
48 | | - |
49 | | - def _test_mm_tosa_MI_pipeline( |
50 | | - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] |
51 | | - ): |
52 | | - ( |
53 | | - ArmTester( |
54 | | - module, |
55 | | - example_inputs=test_data, |
56 | | - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), |
57 | | - ) |
58 | | - .export() |
59 | | - .check_count({"torch.ops.aten.mm.default": 1}) |
60 | | - .check_not(["torch.ops.quantized_decomposed"]) |
61 | | - .to_edge() |
62 | | - .partition() |
63 | | - .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"]) |
64 | | - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) |
65 | | - .to_executorch() |
66 | | - .run_method_and_compare_outputs(inputs=test_data) |
67 | | - ) |
68 | | - |
69 | | - def _test_mm_tosa_BI_pipeline( |
70 | | - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] |
71 | | - ): |
72 | | - ( |
73 | | - ArmTester( |
74 | | - module, |
75 | | - example_inputs=test_data, |
76 | | - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), |
77 | | - ) |
78 | | - .quantize() |
79 | | - .export() |
80 | | - .check_count({"torch.ops.aten.mm.default": 1}) |
81 | | - .check(["torch.ops.quantized_decomposed"]) |
82 | | - .to_edge() |
83 | | - .partition() |
84 | | - .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"]) |
85 | | - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) |
86 | | - .to_executorch() |
87 | | - .run_method_and_compare_outputs(inputs=test_data) |
88 | | - ) |
89 | | - |
90 | | - def _test_mm_ethosu_BI_pipeline( |
91 | | - self, |
92 | | - compile_spec: CompileSpec, |
93 | | - module: torch.nn.Module, |
94 | | - test_data: Tuple[torch.Tensor], |
95 | | - ): |
96 | | - ( |
97 | | - ArmTester( |
98 | | - module, |
99 | | - example_inputs=test_data, |
100 | | - compile_spec=compile_spec, |
101 | | - ) |
102 | | - .quantize() |
103 | | - .export() |
104 | | - .check_count({"torch.ops.aten.mm.default": 1}) |
105 | | - .check(["torch.ops.quantized_decomposed"]) |
106 | | - .to_edge() |
107 | | - .partition() |
108 | | - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) |
109 | | - .to_executorch() |
110 | | - ) |
111 | | - |
112 | | - @parameterized.expand(MM.test_data_generators) |
113 | | - def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]): |
114 | | - test_data = test_data_generator() |
115 | | - self._test_mm_tosa_MI_pipeline(self.MM(), test_data) |
116 | | - |
117 | | - @parameterized.expand(MMSingleInput.test_data_generators) |
118 | | - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
119 | | - def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]): |
120 | | - test_data = test_data_generator() |
121 | | - self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data) |
122 | | - |
123 | | - @parameterized.expand(MM.test_data_generators) |
124 | | - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
125 | | - def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]): |
126 | | - test_data = test_data_generator() |
127 | | - self._test_mm_tosa_BI_pipeline(self.MM(), test_data) |
128 | | - |
129 | | - @parameterized.expand(MMSingleInput.test_data_generators) |
130 | | - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
131 | | - def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]): |
132 | | - test_data = test_data_generator() |
133 | | - self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data) |
134 | | - |
135 | | - # TODO: Enable numerical testing |
136 | | - @parameterized.expand(MM.test_data_generators) |
137 | | - def test_mm_u55_BI(self, test_data_generator: Callable[[], Tuple]): |
138 | | - test_data = test_data_generator() |
139 | | - self._test_mm_ethosu_BI_pipeline( |
140 | | - common.get_u55_compile_spec(), self.MM(), test_data |
141 | | - ) |
142 | | - |
143 | | - # TODO: Enable numerical testing |
144 | | - @parameterized.expand(MMSingleInput.test_data_generators) |
145 | | - def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]): |
146 | | - test_data = test_data_generator() |
147 | | - self._test_mm_ethosu_BI_pipeline( |
148 | | - common.get_u55_compile_spec(), self.MMSingleInput(), test_data |
149 | | - ) |
150 | | - |
151 | | - @parameterized.expand(MM.test_data_generators) |
152 | | - def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]): |
153 | | - test_data = test_data_generator() |
154 | | - self._test_mm_ethosu_BI_pipeline( |
155 | | - common.get_u85_compile_spec(), self.MM(), test_data |
156 | | - ) |
157 | | - |
158 | | - @parameterized.expand(MMSingleInput.test_data_generators) |
159 | | - def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]): |
160 | | - test_data = test_data_generator() |
161 | | - self._test_mm_ethosu_BI_pipeline( |
162 | | - common.get_u85_compile_spec(), self.MMSingleInput(), test_data |
163 | | - ) |
| 19 | +test_t = tuple[torch.Tensor, torch.Tensor] |
| 20 | + |
| 21 | + |
| 22 | +class MM(torch.nn.Module): |
| 23 | + test_data_generators = [ |
| 24 | + lambda: (torch.rand(3, 5), torch.rand(5, 2)), |
| 25 | + lambda: (torch.rand(1, 1), torch.rand(1, 1)), |
| 26 | + lambda: (torch.ones(55, 3), torch.ones(3, 44)), |
| 27 | + lambda: (10000 * torch.randn(1, 10), torch.randn(10, 5)), |
| 28 | + lambda: (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)), |
| 29 | + ] |
| 30 | + aten_op = "torch.ops.aten.mm.default" |
| 31 | + exir_op = "executorch_exir_dialects_edge__ops_aten_mm_default" |
| 32 | + |
| 33 | + def forward(self, x, y): |
| 34 | + return torch.mm(x, y) |
| 35 | + |
| 36 | + |
| 37 | +@parameterized.expand(MM.test_data_generators) |
| 38 | +def test_mm_tosa_MI(test_data_generator: Callable[[], tuple]): |
| 39 | + test_data = test_data_generator() |
| 40 | + TosaPipelineMI[test_t](MM(), test_data, MM.aten_op).run() |
| 41 | + |
| 42 | + |
| 43 | +@parameterized.expand(MM.test_data_generators) |
| 44 | +def test_mm_tosa_BI(test_data_generator: Callable[[], tuple]): |
| 45 | + test_data = test_data_generator() |
| 46 | + TosaPipelineBI[test_t](MM(), test_data, MM.aten_op, MM.exir_op).run() |
| 47 | + |
| 48 | + |
| 49 | +@parameterized.expand(MM.test_data_generators) |
| 50 | +def test_mm_tosa_u55(test_data_generator: Callable[[], tuple]): |
| 51 | + test_data = test_data_generator() |
| 52 | + EthosU55PipelineBI[test_t](MM(), test_data, MM.aten_op).run() |
| 53 | + |
| 54 | + |
| 55 | +@parameterized.expand(MM.test_data_generators) |
| 56 | +def test_mm_tosa_u85(test_data_generator: Callable[[], tuple]): |
| 57 | + test_data = test_data_generator() |
| 58 | + EthosU85PipelineBI[test_t](MM(), test_data, MM.aten_op, MM.exir_op).run() |
| 59 | + |
| 60 | + |
| 61 | +@parameterized.expand(MM.test_data_generators) |
| 62 | +@common.SkipIfNoCorstone300 |
| 63 | +def test_mm_tosa_u55_on_fvp(test_data_generator: Callable[[], tuple]): |
| 64 | + test_data = test_data_generator() |
| 65 | + EthosU55PipelineBI[test_t](MM(), test_data, MM.aten_op, run_on_fvp=True).run() |
| 66 | + |
| 67 | + |
| 68 | +@parameterized.expand(MM.test_data_generators) |
| 69 | +@common.SkipIfNoCorstone320 |
| 70 | +def test_mm_tosa_u85_on_fvp(test_data_generator: Callable[[], tuple]): |
| 71 | + test_data = test_data_generator() |
| 72 | + EthosU85PipelineBI[test_t]( |
| 73 | + MM(), test_data, MM.aten_op, MM.exir_op, run_on_fvp=True |
| 74 | + ).run() |
0 commit comments