Skip to content

Commit e727104

Browse files
committed
Refactor TOSA backend preprocess to use visitor pattern
Signed-off-by: Jerry Ge <[email protected]> Change-Id: I1d7275bdc81c9ad60c27f21fa604c134aa3e3646
1 parent a8e05cc commit e727104

25 files changed

+1439
-971
lines changed

backends/arm/arm_backend.py

Lines changed: 32 additions & 967 deletions
Large diffs are not rendered by default.

backends/arm/arm_partitioner.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import logging
2+
import operator
3+
import os
4+
from typing import final
5+
6+
import torch
7+
from executorch.backends.arm.arm_backend import ArmBackend
8+
from executorch.exir.backend.partitioner import (
9+
DelegationSpec,
10+
Partitioner,
11+
PartitionResult,
12+
)
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from torch._export.exported_program import ExportedProgram
15+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
16+
17+
from torch.fx.passes.operator_support import OperatorSupportBase
18+
19+
logger = logging.getLogger(__name__)
20+
logger.setLevel(logging.WARNING)
21+
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
22+
if TOSA_DBG_VERBOSE:
23+
logging.basicConfig(level=logging.INFO)
24+
logger.setLevel(logging.INFO)
25+
26+
27+
class TOSASupportedOperators(OperatorSupportBase):
28+
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
29+
supported = node.op == "call_function" and node.target in [
30+
exir_ops.edge.aten.add.Tensor,
31+
exir_ops.edge.aten.addmm.default,
32+
exir_ops.edge.aten.permute_copy.default,
33+
exir_ops.edge.aten.hardtanh.default,
34+
exir_ops.edge.aten.convolution.default,
35+
exir_ops.edge.aten.div.Tensor,
36+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
37+
exir_ops.edge.aten.avg_pool2d.default,
38+
exir_ops.edge.aten._softmax.default,
39+
exir_ops.edge.aten.view_copy.default,
40+
exir_ops.edge.aten.clone.default,
41+
operator.getitem,
42+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
43+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
44+
]
45+
return supported
46+
47+
48+
@final
49+
class ArmPartitioner(Partitioner):
50+
compile_spec = []
51+
52+
def __init__(self) -> None:
53+
self.delegation_spec = DelegationSpec(ArmBackend.__name__, self.compile_spec)
54+
55+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
56+
# Run the CapabilityBasedPartitioner to return the largest possible
57+
# subgraphs containing the nodes with the tags
58+
logger.info("ArmPartitioner::partition")
59+
partition_tags = {}
60+
61+
capability_partitioner = CapabilityBasedPartitioner(
62+
exported_program.graph_module,
63+
TOSASupportedOperators(),
64+
allows_single_node_partition=True,
65+
)
66+
partition_list = capability_partitioner.propose_partitions()
67+
for partition in partition_list:
68+
for node in partition.nodes:
69+
tag = f"tag{partition.id}"
70+
node.meta["delegation_tag"] = tag
71+
partition_tags[tag] = self.delegation_spec
72+
73+
return PartitionResult(
74+
tagged_exported_program=exported_program, partition_tags=partition_tags
75+
)

backends/arm/arm_vela.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import struct
3+
import subprocess
4+
import tempfile
5+
6+
import numpy as np
7+
8+
# Pack either input or output tensor block, compose the related arrays into
9+
# per-io structs to simplify runtime use.
10+
def vela_bin_pack_io(prefix, data):
11+
ios = struct.pack("<i", len(data[prefix + "_shape"]))
12+
for i in range(len(data[prefix + "_shape"])):
13+
io_shape = data[prefix + "_shape"][i]
14+
io_elem_size = data[prefix + "_elem_size"][i]
15+
io_offset = data[prefix + "_offset"][i]
16+
io_region = data[prefix + "_region"][i]
17+
assert len(io_shape) <= 4
18+
inp_pad = io_shape.tolist() + [0] * (4 - len(io_shape))
19+
io_struct = struct.pack(
20+
"<iiiiiii", *inp_pad, io_elem_size, io_offset, io_region
21+
)
22+
ios += io_struct
23+
return ios
24+
25+
26+
# Output via Vela to binary stream for ArmBackendEthosU
27+
# WARNING: Do not change this without changing VelaBinStream.cpp as that
28+
# function consumes this format and the two need to align.
29+
def vela_compile(tosa_graph):
30+
with tempfile.TemporaryDirectory() as tmpdir:
31+
tosaname = "out.tosa"
32+
flatbuffer = tosa_graph.serialize()
33+
with open(os.path.join(tmpdir, tosaname), "wb") as f:
34+
f.write(flatbuffer)
35+
36+
# invoke vela
37+
vela_command = (
38+
f"cd {tmpdir}; vela --accelerator-config ethos-u55-128 {tosaname}"
39+
)
40+
subprocess.run([vela_command], shell=True, check=True)
41+
42+
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
43+
blocks = b""
44+
45+
with np.load(np_path, allow_pickle=False) as data:
46+
# Construct our modified output_blocks with data in a form easily
47+
# digested on the device side
48+
bin_blocks = {"vela_bin_stream": b""}
49+
50+
# copy command data through unmodified
51+
bin_blocks["cmd_data"] = data["cmd_data"].tobytes()
52+
53+
# copy weight data through unmodified
54+
bin_blocks["weight_data"] = data["weight_data"].tobytes()
55+
56+
# Add a block for scratch, inputs and outputs; scratch shape is a 1 element
57+
# array giving us size in bytes so extract this and add a block of 0's.
58+
# Currently we preallocated this on the host to provide SRAM for computation.
59+
if not isinstance(data["scratch_shape"][0], np.int64):
60+
raise RuntimeError("Expected scratch to be int64")
61+
block_length = int(data["scratch_shape"][0])
62+
bin_blocks["scratch_data"] = b"\x00" * block_length
63+
64+
# Capture inputs and outputs
65+
bin_blocks["inputs"] = vela_bin_pack_io("input", data)
66+
bin_blocks["outputs"] = vela_bin_pack_io("output", data)
67+
68+
bin_blocks["vela_end_stream"] = b""
69+
70+
# Emit the NPZ regions as:
71+
# - 16 byte block name null terminated string (padded to 16 if name shorter)
72+
# - 4 bytes of int32 block length and 12 bytes of 0's
73+
# - block data (padded to 16 byte alignment at end)
74+
# Repeat for all blocks
75+
for key in bin_blocks.keys():
76+
block_name = bytes(key, "utf8")[:15]
77+
block_name = block_name + b"\x00" * (16 - len(block_name))
78+
79+
# We need the acual unpadded block lengths for hw setup
80+
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)
81+
82+
# Pad block data to multiple of 16 bytes
83+
block_data = bin_blocks[key]
84+
block_data = block_data + b"\x00" * (15 - (len(block_data) - 1) % 16)
85+
86+
block = block_name + block_length + block_data
87+
blocks = blocks + block
88+
89+
return blocks

backends/arm/operators/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2023 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from . import ( # noqa
7+
node_visitor,
8+
op_add,
9+
op_addmm,
10+
op_avg_pool2d,
11+
op_batch_norm,
12+
op_clone,
13+
op_conv2d,
14+
op_dequant,
15+
op_div,
16+
op_get_item,
17+
op_hardtanh,
18+
op_permute,
19+
op_quant,
20+
op_softmax,
21+
op_view,
22+
)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2023 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Dict, List
7+
8+
import serializer.tosa_serializer as ts
9+
import torch
10+
from executorch.backends.arm.tosa_mapping import TosaArg
11+
from torch.export import ExportedProgram
12+
13+
14+
class NodeVisitor:
15+
"""
16+
Node Visitor pattern for lowering edge IR to TOSA
17+
"""
18+
19+
def __init__(self, exported_program: ExportedProgram):
20+
self._exported_program = exported_program or None
21+
22+
def define_node(
23+
self,
24+
node: torch.fx.Node,
25+
tosa_graph: ts.TosaSerializer,
26+
inputs: List[TosaArg],
27+
output: TosaArg,
28+
is_quant_node: bool,
29+
) -> None:
30+
raise NotImplementedError("NodeVisitor must be extended.")
31+
32+
33+
# container for all node visitors
34+
_node_visitor_dict = {}
35+
36+
37+
def register_node_visitor(visitor):
38+
_node_visitor_dict[visitor.target] = visitor
39+
40+
41+
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
42+
node_visitors = {}
43+
for target, visitor in _node_visitor_dict.items():
44+
node_visitors[target] = visitor(*args)
45+
46+
return node_visitors

backends/arm/operators/op_add.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2023 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List
7+
8+
import serializer.tosa_serializer as ts
9+
import torch
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_quant_utils import (
16+
buildRescaleFromInt32,
17+
buildRescaleToInt32,
18+
)
19+
from executorch.backends.arm.tosa_utils import broadcastShapes, getNodeArgs
20+
from serializer.tosa_serializer import TosaOp
21+
22+
23+
@register_node_visitor
24+
class AddVisitor(NodeVisitor):
25+
target = "aten.add.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: torch.fx.Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
is_quant_node: bool,
37+
) -> None:
38+
if is_quant_node:
39+
# Single input or not
40+
if len(node.all_input_nodes) == 1:
41+
input_node_A = node.all_input_nodes[0]
42+
input_node_B = node.all_input_nodes[0]
43+
else:
44+
input_node_A, input_node_B = node.all_input_nodes
45+
46+
# Get input scale_factor and zero_points for A, B
47+
input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
48+
input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)
49+
50+
max_scale_2x = 2.0 * max(input_A_scale.number, input_B_scale.number)
51+
inputA_rescale_scale = input_A_scale.number / max_scale_2x
52+
inputB_rescale_scale = input_B_scale.number / max_scale_2x
53+
54+
input_A_rescaled_to_int32 = buildRescaleToInt32(
55+
tosa_graph,
56+
input_A,
57+
input_A_zp.number,
58+
inputA_rescale_scale,
59+
)
60+
61+
input_B_rescaled_to_int32 = buildRescaleToInt32(
62+
tosa_graph,
63+
input_B,
64+
input_B_zp.number,
65+
inputB_rescale_scale,
66+
)
67+
68+
## Do the INT32 Add
69+
broadcasted_shape = broadcastShapes(input_A.shape, input_B.shape)
70+
add_res = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
71+
tosa_graph.addOperator(
72+
TosaOp.Op().ADD,
73+
[
74+
input_A_rescaled_to_int32.name,
75+
input_B_rescaled_to_int32.name,
76+
],
77+
[add_res.name],
78+
None,
79+
)
80+
81+
# Output
82+
output_node = list(node.users)[0]
83+
_, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
84+
output_rescale_scale = max_scale_2x / (output_scale.number)
85+
86+
# Rescale Back to INT8
87+
buildRescaleFromInt32(
88+
tosa_graph,
89+
add_res.name,
90+
output.name,
91+
output_zp.number,
92+
output_rescale_scale,
93+
)
94+
else:
95+
# FP32 Add lowering
96+
tosa_graph.addOperator(
97+
TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], [output.name], None
98+
)

0 commit comments

Comments
 (0)