Skip to content

Commit 78752a0

Browse files
Arm backend: Refactor Arm backend pass unittests (#8368)
Adds a new TestPassPipeline for testing single passes and updates pass test to use the new pipeline. test_ioquantisation_pass_u55_BI does not use the pipeline as a special case, since it differs substantially from the other tests.
1 parent 6d49702 commit 78752a0

8 files changed

+302
-276
lines changed
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

7-
import unittest
6+
from typing import Tuple
87

98
import torch
109
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
1110

12-
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
1312

14-
from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
13+
input_t = Tuple[torch.Tensor] # Input x
1514

1615

1716
class Int64Model(torch.nn.Module):
1817

1918
def forward(self, x: torch.Tensor):
2019
return x + 3
2120

22-
def get_inputs(self):
21+
def get_inputs(self) -> input_t:
2322
return (torch.rand(4),)
2423

2524

26-
class TestCastInt64Pass(unittest.TestCase):
27-
28-
def test_int64_model(self):
29-
module = Int64Model()
30-
test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass])
31-
tester = (
32-
ArmTester(
33-
module,
34-
example_inputs=module.get_inputs(),
35-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
36-
)
37-
.export()
38-
.to_edge()
39-
.run_passes(test_pass_stage)
40-
.run_method_and_compare_outputs()
41-
)
42-
exported_program = tester.get_artifact("RunPasses").exported_program()
43-
for state in exported_program.state_dict:
44-
assert exported_program.state_dict[state].dtype != torch.int64
25+
def test_int64_model_tosa_BI():
26+
module = Int64Model()
27+
op_checks = {
28+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
29+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
30+
}
31+
pipeline = TestPassPipeline[input_t](
32+
module,
33+
module.get_inputs(),
34+
tosa_version="TOSA-0.80+BI",
35+
ops_before_pass=op_checks,
36+
ops_after_pass=op_checks,
37+
passes_with_exported_program=[CastInt64ToInt32Pass],
38+
)
39+
pipeline.pop_stage("quantize")
40+
pipeline.run()
41+
42+
exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()
43+
for state in exported_program.state_dict:
44+
assert exported_program.state_dict[state].dtype == torch.int32
Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,50 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

7-
import unittest
6+
from typing import Tuple
87

98
import torch
109
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1110
FoldAndAnnotateQParamsPass,
1211
)
12+
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
1313

14-
from executorch.backends.arm.test import common
15-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1614

17-
from executorch.backends.xnnpack.test.tester.tester import RunPasses
15+
input_t = Tuple[torch.Tensor, torch.Tensor] # Input x, y
1816

1917

2018
class SimpleQuantizeModel(torch.nn.Module):
2119
def forward(self, x, y):
2220
return x + torch.max((x + x), (y + y))
2321

24-
def get_inputs(self):
22+
def get_inputs(self) -> input_t:
2523
return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7))
2624

2725

28-
class TestFoldAndAnnotateQParamsPass(unittest.TestCase):
26+
def test_fold_qdq_pass_tosa_BI():
2927
"""
3028
Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into
3129
the node and stores the quantization parameters in meta.
32-
"""
3330
34-
def test_fold_qdq_pass(self):
35-
"""
36-
Check that the pass runs for add operation and that one q node and one dq node
37-
is removed from the representation.
38-
"""
39-
module = SimpleQuantizeModel()
40-
test_pass_stage = RunPasses([FoldAndAnnotateQParamsPass])
41-
(
42-
ArmTester(
43-
module,
44-
example_inputs=module.get_inputs(),
45-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
46-
)
47-
.quantize()
48-
.export()
49-
.to_edge()
50-
.check_count(
51-
{
52-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7,
53-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6,
54-
}
55-
)
56-
.run_passes(test_pass_stage)
57-
.check_count(
58-
{
59-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
60-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
61-
}
62-
)
63-
)
31+
Check that the pass runs for add operation and that one q node and one dq node
32+
is removed from the representation.
33+
"""
34+
module = SimpleQuantizeModel()
35+
pipeline = TestPassPipeline[input_t](
36+
module,
37+
module.get_inputs(),
38+
tosa_version="TOSA-0.80+BI",
39+
ops_before_pass={
40+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7,
41+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6,
42+
},
43+
ops_after_pass={
44+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
45+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
46+
},
47+
pass_list=[FoldAndAnnotateQParamsPass],
48+
)
49+
pipeline.pop_stage(-1) # Do not compare output
50+
pipeline.run()

backends/arm/test/passes/test_fuse_batchnorm_pass.py

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
6-
import unittest
5+
6+
from typing import Tuple
77

88
import torch
99
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
1010
from executorch.backends.arm.test import common
11-
from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
12-
from parameterized import parameterized
11+
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
12+
13+
input_t = Tuple[torch.Tensor] # Input x
1314

1415

1516
class MergeOneOfTwoBN(torch.nn.Module):
@@ -35,7 +36,7 @@ def __init__(self, affine: bool):
3536
self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
3637
self.relu6 = torch.nn.ReLU6()
3738

38-
def get_inputs(self) -> tuple[torch.Tensor]:
39+
def get_inputs(self) -> input_t:
3940
return (torch.randn(1, 3, 256, 256),)
4041

4142
def forward(self, x):
@@ -72,7 +73,7 @@ def __init__(self, affine: bool):
7273
self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
7374
self.relu6 = torch.nn.ReLU6()
7475

75-
def get_inputs(self) -> tuple[torch.Tensor]:
76+
def get_inputs(self) -> input_t:
7677
return (torch.randn(1, 3, 256, 256),)
7778

7879
def forward(self, x):
@@ -110,7 +111,7 @@ def __init__(self, affine: bool):
110111
self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
111112
self.relu6 = torch.nn.ReLU6()
112113

113-
def get_inputs(self) -> tuple[torch.Tensor]:
114+
def get_inputs(self) -> input_t:
114115
return (torch.randn(1, 3, 256, 256),)
115116

116117
def forward(self, x):
@@ -126,33 +127,23 @@ def forward(self, x):
126127
return z, a
127128

128129

129-
modules = [
130-
MergeOneOfTwoBN(True),
131-
MergeOneOfTwoBN(False),
132-
MergeTwosOfTwoBN(True),
133-
MergeNoBN(True),
134-
]
135-
136-
137-
class TestFuseBatchnormPass(unittest.TestCase):
138-
139-
@parameterized.expand(modules)
140-
def test_fuse_batchnorm_tosa_MI(self, module):
141-
"""Test various cases where the batchnorm should and shouldn't be fused."""
142-
inputs = module.get_inputs()
143-
test_pass_stage = RunPasses(passes_with_exported_program=[FuseBatchnorm2DPass])
144-
(
145-
(
146-
ArmTester(
147-
module,
148-
example_inputs=inputs,
149-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
150-
)
151-
.export()
152-
.to_edge()
153-
.check_count(module.ops_before_pass)
154-
.run_passes(test_pass_stage)
155-
.check_count(module.ops_after_pass)
156-
.run_method_and_compare_outputs()
157-
)
158-
)
130+
modules = {
131+
"merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
132+
"merge_one_of_two_bn": MergeOneOfTwoBN(False),
133+
"merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
134+
"merge_no_bn_affine": MergeNoBN(True),
135+
}
136+
137+
138+
@common.parametrize("module", modules)
139+
def test_fuse_batchnorm_tosa_MI(module):
140+
"""Test various cases where the batchnorm should and shouldn't be fused."""
141+
pipeline = TestPassPipeline[input_t](
142+
module,
143+
module.get_inputs(),
144+
tosa_version="TOSA-0.80+MI",
145+
ops_before_pass=module.ops_before_pass,
146+
ops_after_pass=module.ops_after_pass,
147+
passes_with_exported_program=[FuseBatchnorm2DPass],
148+
)
149+
pipeline.run()
Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,46 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

7-
import unittest
6+
7+
from typing import Tuple
88

99
import torch
1010
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1111
FoldAndAnnotateQParamsPass,
1212
)
1313
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
14+
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
1415

15-
from executorch.backends.arm.test import common
16-
17-
from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
16+
input_t = Tuple[torch.Tensor] # Input x
1817

1918

2019
class Sigmoid(torch.nn.Module):
2120

2221
def forward(self, x: torch.Tensor):
2322
return x.sigmoid()
2423

25-
def get_inputs(self):
24+
def get_inputs(self) -> input_t:
2625
return (torch.rand(4),)
2726

2827

29-
class TestInsertTablePass(unittest.TestCase):
30-
31-
def test_insert_table_tosa_BI(self):
32-
module = Sigmoid()
33-
test_pass_stage = RunPasses(
34-
[FoldAndAnnotateQParamsPass],
35-
passes_with_exported_program=[InsertTableOpsPass],
36-
)
37-
(
38-
ArmTester(
39-
module,
40-
example_inputs=module.get_inputs(),
41-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
42-
)
43-
.quantize()
44-
.export()
45-
.to_edge()
46-
.run_passes(test_pass_stage)
47-
.check("tosa._table")
48-
.check_count(
49-
{
50-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
51-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
52-
}
53-
)
54-
.check_not(["aten_sigmoid_default"])
55-
)
28+
def test_insert_table_tosa_BI():
29+
module = Sigmoid()
30+
pipeline = TestPassPipeline[input_t](
31+
module,
32+
module.get_inputs(),
33+
tosa_version="TOSA-0.80+BI",
34+
ops_before_pass={},
35+
ops_after_pass={
36+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
37+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
38+
"tosa._table": 1,
39+
},
40+
ops_not_after_pass=["aten_sigmoid_default"],
41+
pass_list=[FoldAndAnnotateQParamsPass],
42+
passes_with_exported_program=[InsertTableOpsPass],
43+
)
44+
pipeline.pop_stage(-1) # Do not compare output
45+
46+
pipeline.run()

0 commit comments

Comments
 (0)