Skip to content

Commit 7064d03

Browse files
committed
Arm backend: Move to use tosa_tools.v0_80 for TOSA operations
NFC that moves to use the TOSA tools from a different namespace. Preparation to run 0.80 and 1.0 along side during the transition period. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I69d6d6a46fa95213afde8952f151fc9b41065099
1 parent 534b062 commit 7064d03

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+157
-166
lines changed

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ ignore_missing_imports = True
8080
[mypy-serializer.*]
8181
ignore_missing_imports = True
8282

83+
[mypy-tosa_tools.*]
84+
ignore_missing_imports = True
85+
8386
[mypy-setuptools.*]
8487
ignore_missing_imports = True
8588

backends/arm/operators/node_visitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from typing import Dict, List
99

10-
import serializer.tosa_serializer as ts # type: ignore
1110
import torch
11+
12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1213
from executorch.backends.arm.tosa_mapping import TosaArg
1314
from executorch.backends.arm.tosa_specification import TosaSpecification
1415
from torch.export import ExportedProgram

backends/arm/operators/op_abs.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,13 @@
99
import executorch.backends.arm.tosa_quant_utils as tqutils
1010
import executorch.backends.arm.tosa_utils as tutils
1111

12-
import serializer.tosa_serializer as ts # type: ignore
12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1313
from executorch.backends.arm.operators.node_visitor import (
1414
NodeVisitor,
1515
register_node_visitor,
1616
)
1717
from executorch.backends.arm.tosa_mapping import TosaArg
1818
from executorch.backends.arm.tosa_specification import TosaSpecification
19-
20-
from serializer.tosa_serializer import TosaOp
2119
from torch.fx import Node
2220

2321

@@ -70,7 +68,7 @@ def define_node(
7068

7169
# Do the INT32 Abs
7270
tosa_graph.addOperator(
73-
TosaOp.Op().ABS,
71+
ts.TosaOp.Op().ABS,
7472
[
7573
rescaled_inputs[0].name,
7674
],
@@ -126,7 +124,7 @@ def define_node(
126124

127125
# MI lowering
128126
tosa_graph.addOperator(
129-
TosaOp.Op().ABS,
127+
ts.TosaOp.Op().ABS,
130128
[inputs[0].name],
131129
[output.name],
132130
None,

backends/arm/operators/op_add.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@
1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

13-
import serializer.tosa_serializer as ts # type: ignore
13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
1616
register_node_visitor,
1717
)
1818
from executorch.backends.arm.tosa_mapping import TosaArg
1919
from executorch.backends.arm.tosa_specification import TosaSpecification
20-
from serializer.tosa_serializer import TosaOp
2120
from torch.fx import Node
2221

2322

@@ -73,7 +72,7 @@ def define_node(
7372

7473
# Do the INT32 Add
7574
tosa_graph.addOperator(
76-
TosaOp.Op().ADD,
75+
ts.TosaOp.Op().ADD,
7776
[input1.name, input2.name],
7877
[add_output.name],
7978
None,
@@ -119,7 +118,7 @@ def define_node(
119118

120119
# MI lowering
121120
tosa_graph.addOperator(
122-
TosaOp.Op().ADD,
121+
ts.TosaOp.Op().ADD,
123122
[input1.name, input2.name],
124123
[output.name],
125124
None,

backends/arm/operators/op_amax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import List
66

7-
import serializer.tosa_serializer as ts
7+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
88
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
99
from executorch.backends.arm.operators.node_visitor import (
1010
NodeVisitor,
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from serializer.tosa_serializer import TosaOp
1514
from torch.fx import Node
1615

1716

@@ -48,5 +47,5 @@ def define_node(
4847
attr.AxisAttribute(input.dim_order.index(dim))
4948

5049
tosa_graph.addOperator(
51-
TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
50+
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
5251
)

backends/arm/operators/op_amin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import List
66

7-
import serializer.tosa_serializer as ts
7+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
88
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
99
from executorch.backends.arm.operators.node_visitor import (
1010
NodeVisitor,
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from serializer.tosa_serializer import TosaOp
1514
from torch.fx import Node
1615

1716

@@ -48,5 +47,5 @@ def define_node(
4847
attr.AxisAttribute(input.dim_order.index(dim))
4948

5049
tosa_graph.addOperator(
51-
TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
50+
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
5251
)

backends/arm/operators/op_any.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
# pyre-unsafe
77
from typing import cast, List
88

9-
import serializer.tosa_serializer as ts # type: ignore
9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1010
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1111
NodeVisitor,
1212
register_node_visitor,
1313
)
1414

1515
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
16-
from serializer.tosa_serializer import TosaOp
1716
from torch.fx import Node
1817

1918

@@ -49,5 +48,5 @@ def define_node(
4948
attr.AxisAttribute(inputs[0].dim_order.index(dim))
5049

5150
tosa_graph.addOperator(
52-
TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
51+
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
5352
)

backends/arm/operators/op_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts # type: ignore
109
import torch
1110

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
1213
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1314
get_input_qparams,
1415
get_output_qparams,

backends/arm/operators/op_bmm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
# pyre-unsafe
88
from typing import List
99

10-
import serializer.tosa_serializer as ts # type: ignore
1110
import torch
1211

12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
13+
1314
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1415
get_input_qparams,
1516
get_output_qparams,
@@ -20,7 +21,6 @@
2021
)
2122
from executorch.backends.arm.tosa_mapping import TosaArg
2223
from executorch.backends.arm.tosa_quant_utils import build_rescale
23-
from serializer.tosa_serializer import TosaOp
2424

2525

2626
@register_node_visitor
@@ -64,7 +64,7 @@ def define_node(
6464
attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)
6565

6666
tosa_graph.addOperator(
67-
TosaOp.Op().MATMUL,
67+
ts.TosaOp.Op().MATMUL,
6868
[inputs[0].name, inputs[1].name],
6969
[bmm_output_name],
7070
attr,

backends/arm/operators/op_cat.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77

88
from typing import List
99

10-
import serializer.tosa_serializer as ts # type: ignore
10+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1111
from executorch.backends.arm.operators.node_visitor import (
1212
NodeVisitor,
1313
register_node_visitor,
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from serializer.tosa_serializer import TosaOp
1716
from torch.fx import Node
1817

1918

@@ -42,5 +41,8 @@ def define_node(
4241
attr.AxisAttribute(dim)
4342

4443
tosa_graph.addOperator(
45-
TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr
44+
ts.TosaOp.Op().CONCAT,
45+
[tensor.name for tensor in tensors],
46+
[output.name],
47+
attr,
4648
)

backends/arm/operators/op_clamp.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@
88

99
from typing import Any, List, Tuple
1010

11-
import serializer.tosa_serializer as ts # type: ignore
12-
1311
import torch
12+
13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
1616
register_node_visitor,
1717
)
1818

1919
from executorch.backends.arm.tosa_mapping import TosaArg
2020
from executorch.backends.arm.tosa_specification import TosaSpecification
21-
from serializer.tosa_serializer import TosaOp
2221
from torch.fx import Node
2322

2423

@@ -51,7 +50,7 @@ def _create_clamp_node(
5150
min_fp32,
5251
max_fp32,
5352
)
54-
tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
53+
tosa_graph.addOperator(ts.TosaOp.Op().CLAMP, [input_name], [output_name], attr)
5554

5655
def _get_min_max_arguments(
5756
self, node: Node, dtype_min: int | float, dtype_max: int | float

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from typing import List
99

10-
import serializer.tosa_serializer as ts
1110
import torch
1211

12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
13+
1314
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1415
get_input_qparams,
1516
)
@@ -18,7 +19,6 @@
1819
register_node_visitor,
1920
)
2021
from executorch.backends.arm.tosa_mapping import TosaArg
21-
from serializer.tosa_serializer import TosaOp
2222

2323

2424
@register_node_visitor
@@ -71,4 +71,6 @@ def define_node(
7171
attr = ts.TosaSerializerAttribute()
7272
attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp)
7373

74-
tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr)
74+
tosa_graph.addOperator(
75+
ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr
76+
)

backends/arm/operators/op_conv2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts # type: ignore
109
import torch
1110

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
1213
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1314
get_input_qparams,
1415
get_output_qparams,

backends/arm/operators/op_eq.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111

12-
import serializer.tosa_serializer as ts # type: ignore
12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1313
from executorch.backends.arm.operators.node_visitor import (
1414
NodeVisitor,
1515
register_node_visitor,
1616
)
1717
from executorch.backends.arm.tosa_mapping import TosaArg
18-
from serializer.tosa_serializer import TosaOp
1918

2019
from torch.fx import Node
2120

@@ -51,7 +50,7 @@ def define_node(
5150

5251
# Do the equal comparison
5352
tosa_graph.addOperator(
54-
TosaOp.Op().EQUAL,
53+
ts.TosaOp.Op().EQUAL,
5554
[input_nodes[0].name, input_nodes[1].name],
5655
output.name,
5756
None,

backends/arm/operators/op_exp.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts # type: ignore
9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1010
from executorch.backends.arm.operators.node_visitor import (
1111
NodeVisitor,
1212
register_node_visitor,
1313
)
1414
from executorch.backends.arm.tosa_mapping import TosaArg
1515
from executorch.backends.arm.tosa_specification import TosaSpecification
16-
17-
from serializer.tosa_serializer import TosaOp
1816
from torch.fx import Node
1917

2018

@@ -39,4 +37,4 @@ def define_node(
3937
assert len(node.all_input_nodes) == 1
4038
assert inputs[0].dtype == output.dtype == ts.DType.FP32
4139

42-
tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])
40+
tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name])

backends/arm/operators/op_full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
import serializer.tosa_serializer as ts # type: ignore
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1212
from executorch.backends.arm.operators.node_visitor import (
1313
NodeVisitor,
1414
register_node_visitor,

backends/arm/operators/op_ge.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111

12-
import serializer.tosa_serializer as ts # type: ignore
12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1313
from executorch.backends.arm.operators.node_visitor import (
1414
NodeVisitor,
1515
register_node_visitor,
1616
)
1717
from executorch.backends.arm.tosa_mapping import TosaArg
18-
from serializer.tosa_serializer import TosaOp
1918

2019
from torch.fx import Node
2120

@@ -50,7 +49,7 @@ def define_node(
5049
input_nodes = rescaled_inputs
5150

5251
tosa_graph.addOperator(
53-
TosaOp.Op().GREATER_EQUAL,
52+
ts.TosaOp.Op().GREATER_EQUAL,
5453
[input_nodes[0].name, input_nodes[1].name],
5554
[output.name],
5655
None,

backends/arm/operators/op_get_item.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts # type: ignore
109
import torch
10+
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1112
from executorch.backends.arm.operators.node_visitor import (
1213
NodeVisitor,
1314
register_node_visitor,
1415
)
1516
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from serializer.tosa_serializer import TosaOp
1717

1818

1919
@register_node_visitor
@@ -32,4 +32,4 @@ def define_node(
3232
) -> None:
3333
item_name = inputs[0].name
3434
## Simply add an identityOp
35-
tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name])
35+
tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [item_name], [output.name])

0 commit comments

Comments
 (0)