Skip to content

Arm backend: Tosa tools update #9451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ ignore_missing_imports = True
[mypy-serializer.*]
ignore_missing_imports = True

[mypy-tosa_tools.*]
ignore_missing_imports = True

[mypy-setuptools.*]
ignore_missing_imports = True

Expand Down
3 changes: 2 additions & 1 deletion backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from typing import Dict, List

import serializer.tosa_serializer as ts # type: ignore
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.export import ExportedProgram
Expand Down
8 changes: 3 additions & 5 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -70,7 +68,7 @@ def define_node(

# Do the INT32 Abs
tosa_graph.addOperator(
TosaOp.Op().ABS,
ts.TosaOp.Op().ABS,
[
rescaled_inputs[0].name,
],
Expand Down Expand Up @@ -126,7 +124,7 @@ def define_node(

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ABS,
ts.TosaOp.Op().ABS,
[inputs[0].name],
[output.name],
None,
Expand Down
7 changes: 3 additions & 4 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -82,7 +81,7 @@ def define_node(

# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[add_output.name],
None,
Expand Down Expand Up @@ -135,7 +134,7 @@ def define_node(

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ADD,
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[output.name],
None,
Expand Down
5 changes: 2 additions & 3 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import tosa_tools.v0_80.serializer.tosa_serializer as ts
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -48,5 +47,5 @@ def define_node(
attr.AxisAttribute(input.dim_order.index(dim))

tosa_graph.addOperator(
TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
)
5 changes: 2 additions & 3 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import tosa_tools.v0_80.serializer.tosa_serializer as ts
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -48,5 +47,5 @@ def define_node(
attr.AxisAttribute(input.dim_order.index(dim))

tosa_graph.addOperator(
TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
)
5 changes: 2 additions & 3 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
# pyre-unsafe
from typing import cast, List

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -49,5 +48,5 @@ def define_node(
attr.AxisAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
)
3 changes: 2 additions & 1 deletion backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand All @@ -20,7 +21,6 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale
from serializer.tosa_serializer import TosaOp


@register_node_visitor
Expand Down Expand Up @@ -64,7 +64,7 @@ def define_node(
attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

tosa_graph.addOperator(
TosaOp.Op().MATMUL,
ts.TosaOp.Op().MATMUL,
[inputs[0].name, inputs[1].name],
[bmm_output_name],
attr,
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

from typing import List

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -42,5 +41,8 @@ def define_node(
attr.AxisAttribute(dim)

tosa_graph.addOperator(
TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr
ts.TosaOp.Op().CONCAT,
[tensor.name for tensor in tensors],
[output.name],
attr,
)
7 changes: 3 additions & 4 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@

from typing import Any, List, Tuple

import serializer.tosa_serializer as ts # type: ignore

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -51,7 +50,7 @@ def _create_clamp_node(
min_fp32,
max_fp32,
)
tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
tosa_graph.addOperator(ts.TosaOp.Op().CLAMP, [input_name], [output_name], attr)

def _get_min_max_arguments(
self, node: Node, dtype_min: int | float, dtype_max: int | float
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/operators/op_constant_pad_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

from typing import List

import serializer.tosa_serializer as ts
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
Expand All @@ -18,7 +19,6 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
Expand Down Expand Up @@ -71,4 +71,6 @@ def define_node(
attr = ts.TosaSerializerAttribute()
attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp)

tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr)
tosa_graph.addOperator(
ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr
)
3 changes: 2 additions & 1 deletion backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand Down
5 changes: 2 additions & 3 deletions backends/arm/operators/op_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

import executorch.backends.arm.tosa_quant_utils as tqutils

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp

from torch.fx import Node

Expand Down Expand Up @@ -53,7 +52,7 @@ def define_node(

# Do the equal comparison
tosa_graph.addOperator(
TosaOp.Op().EQUAL,
ts.TosaOp.Op().EQUAL,
[input_nodes[0].name, input_nodes[1].name],
output.name,
None,
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_erf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch.fx

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp


@register_node_visitor
Expand Down Expand Up @@ -41,4 +41,4 @@ def define_node(
if not (inputs[0].dtype == ts.DType.FP32):
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
# MI lowering
tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])
6 changes: 2 additions & 4 deletions backends/arm/operators/op_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

from serializer.tosa_serializer import TosaOp
from torch.fx import Node


Expand Down Expand Up @@ -46,4 +44,4 @@ def define_node(
f"{inputs[0].dtype} and output dtype: {output.dtype}"
)

tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])
tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name])
2 changes: 1 addition & 1 deletion backends/arm/operators/op_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand Down
5 changes: 2 additions & 3 deletions backends/arm/operators/op_ge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

import executorch.backends.arm.tosa_quant_utils as tqutils

import serializer.tosa_serializer as ts # type: ignore
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp

from torch.fx import Node

Expand Down Expand Up @@ -52,7 +51,7 @@ def define_node(
input_nodes = rescaled_inputs

tosa_graph.addOperator(
TosaOp.Op().GREATER_EQUAL,
ts.TosaOp.Op().GREATER_EQUAL,
[input_nodes[0].name, input_nodes[1].name],
[output.name],
None,
Expand Down
Loading
Loading