Skip to content

Refactor TOSA backend preprocess to use visitor pattern #1099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
999 changes: 32 additions & 967 deletions backends/arm/arm_backend.py

Large diffs are not rendered by default.

75 changes: 75 additions & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
import operator
import os
from typing import final

import torch
from executorch.backends.arm.arm_backend import ArmBackend
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase

# Module logger: WARNING by default, INFO when TOSA_DBG_VERBOSE=1 is set in
# the environment.
logger = logging.getLogger(__name__)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if not TOSA_DBG_VERBOSE:
    logger.setLevel(logging.WARNING)
else:
    # Verbose debugging also configures the root logger at INFO level.
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


class TOSASupportedOperators(OperatorSupportBase):
    """Operator-support predicate listing the call targets accepted for TOSA."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Return True when ``node`` is a ``call_function`` on a listed target.

        ``submodules`` is part of the ``OperatorSupportBase`` signature but is
        not consulted here.
        """
        # Only function calls are candidates; everything else is rejected
        # before the target list is even built.
        if node.op != "call_function":
            return False

        accepted_targets = [
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.addmm.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
            operator.getitem,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        ]
        return node.target in accepted_targets


@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags TOSA-supported nodes for delegation to ArmBackend."""

    # Class-level compile spec; the same list object is handed to the
    # DelegationSpec of every instance.
    compile_spec = []

    def __init__(self) -> None:
        self.delegation_spec = DelegationSpec(ArmBackend.__name__, self.compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag every node of each proposed partition with a delegation tag.

        CapabilityBasedPartitioner proposes the partitions from the nodes
        accepted by TOSASupportedOperators (single-node partitions allowed).
        Each node in a partition gets ``meta["delegation_tag"]`` set to
        ``tag<partition id>``, and that tag is mapped to this partitioner's
        delegation spec.

        Returns:
            A PartitionResult holding the same ``exported_program`` (its nodes
            tagged in place) and the tag -> DelegationSpec mapping.
        """
        logger.info("ArmPartitioner::partition")

        proposed_partitions = CapabilityBasedPartitioner(
            exported_program.graph_module,
            TOSASupportedOperators(),
            allows_single_node_partition=True,
        ).propose_partitions()

        partition_tags = {}
        for proposed in proposed_partitions:
            tag = f"tag{proposed.id}"
            for member in proposed.nodes:
                member.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
89 changes: 89 additions & 0 deletions backends/arm/arm_vela.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import struct
import subprocess
import tempfile

import numpy as np

# Pack either input or output tensor block, compose the related arrays into
# per-io structs to simplify runtime use.
def vela_bin_pack_io(prefix, data):
    """Pack the ``<prefix>_*`` arrays of ``data`` into one little-endian blob.

    Layout: an int32 entry count, then one record of seven int32 values per
    entry -- the shape zero-padded to four dimensions, followed by the element
    size, the offset and the region.

    Args:
        prefix: Key prefix selecting the arrays, e.g. ``"input"`` or
            ``"output"``.
        data: Mapping providing ``<prefix>_shape``, ``<prefix>_elem_size``,
            ``<prefix>_offset`` and ``<prefix>_region``, each indexable per
            entry. Every shape must provide ``tolist()``.

    Returns:
        The packed ``bytes``.

    Raises:
        AssertionError: If a shape has more than four dimensions.
    """
    # Look every array up exactly once: an ``NpzFile`` re-reads the archive
    # member on each ``data[key]`` access, so indexing it inside the loop
    # repeats that work for every entry.
    shapes = data[prefix + "_shape"]
    count = len(shapes)
    header = struct.pack("<i", count)
    if count == 0:
        # No entries: the per-entry arrays are never consulted.
        return header

    elem_sizes = data[prefix + "_elem_size"]
    offsets = data[prefix + "_offset"]
    regions = data[prefix + "_region"]

    record = struct.Struct("<iiiiiii")
    parts = [header]
    for i, io_shape in enumerate(shapes):
        assert len(io_shape) <= 4
        padded_shape = io_shape.tolist() + [0] * (4 - len(io_shape))
        parts.append(record.pack(*padded_shape, elem_sizes[i], offsets[i], regions[i]))
    return b"".join(parts)


# Output via Vela to binary stream for ArmBackendEthosU
# WARNING: Do not change this without changing VelaBinStream.cpp as that
# function consumes this format and the two need to align.
def vela_compile(tosa_graph):
    """Compile ``tosa_graph`` with Vela and return the packed binary stream.

    The graph is serialized into a temporary directory, the ``vela`` command
    line tool is run on it with the ``ethos-u55-128`` accelerator config, and
    the resulting ``output/out_sg0_vela.npz`` is repacked into named blocks.

    Stream layout, repeated for every block:
      - 16 byte block name, null terminated (padded to 16 if name shorter)
      - 4 bytes of int32 block length and 12 bytes of 0's
      - block data (padded to 16 byte alignment at end)

    WARNING: Do not change this format without changing VelaBinStream.cpp,
    which consumes it.

    Args:
        tosa_graph: Object whose ``serialize()`` returns the bytes written to
            ``out.tosa``.

    Returns:
        The concatenated blocks as ``bytes``.

    Raises:
        subprocess.CalledProcessError: If ``vela`` exits with a non-zero status.
        FileNotFoundError: If the ``vela`` executable cannot be found.
        RuntimeError: If the scratch size in the Vela output is not an int64.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tosaname = "out.tosa"
        flatbuffer = tosa_graph.serialize()
        with open(os.path.join(tmpdir, tosaname), "wb") as f:
            f.write(flatbuffer)

        # Invoke vela from inside tmpdir. An argument list plus ``cwd`` avoids
        # building a shell string, so a temp path containing spaces or shell
        # metacharacters cannot break the command, and vela never runs in the
        # wrong directory because a ``cd`` failed.
        subprocess.run(
            ["vela", "--accelerator-config", "ethos-u55-128", tosaname],
            cwd=tmpdir,
            check=True,
        )

        np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")

        with np.load(np_path, allow_pickle=False) as data:
            # Construct our modified output blocks with data in a form easily
            # digested on the device side. Insertion order is the order the
            # blocks are emitted in.
            bin_blocks = {"vela_bin_stream": b""}

            # Copy command data through unmodified
            bin_blocks["cmd_data"] = data["cmd_data"].tobytes()

            # Copy weight data through unmodified
            bin_blocks["weight_data"] = data["weight_data"].tobytes()

            # Scratch shape is a 1 element array giving the size in bytes;
            # extract it and add a block of 0's of that length. Currently this
            # is preallocated on the host to provide SRAM for computation.
            scratch_size = data["scratch_shape"][0]
            if not isinstance(scratch_size, np.int64):
                raise RuntimeError("Expected scratch to be int64")
            bin_blocks["scratch_data"] = b"\x00" * int(scratch_size)

            # Capture inputs and outputs
            bin_blocks["inputs"] = vela_bin_pack_io("input", data)
            bin_blocks["outputs"] = vela_bin_pack_io("output", data)

            bin_blocks["vela_end_stream"] = b""

            chunks = []
            for key, payload in bin_blocks.items():
                # Name is truncated to 15 bytes so at least one terminating
                # null always fits in the 16 byte field.
                name_field = bytes(key, "utf8")[:15].ljust(16, b"\x00")

                # We need the actual unpadded block lengths for hw setup
                length_field = struct.pack("<iiii", len(payload), 0, 0, 0)

                # Pad block data to a multiple of 16 bytes
                padding = b"\x00" * (-len(payload) % 16)

                chunks.append(name_field + length_field + payload + padding)

        return b"".join(chunks)
22 changes: 22 additions & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import ( # noqa
node_visitor,
op_add,
op_addmm,
op_avg_pool2d,
op_batch_norm,
op_clone,
op_conv2d,
op_dequant,
op_div,
op_get_item,
op_hardtanh,
op_permute,
op_quant,
op_softmax,
op_view,
)
46 changes: 46 additions & 0 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2023 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.export import ExportedProgram


class NodeVisitor:
    """
    Base class for the visitors that lower edge IR nodes to TOSA.

    Subclasses are expected to override ``define_node``; the base
    implementation only raises.
    """

    def __init__(self, exported_program: ExportedProgram):
        # Any falsy value is stored as None.
        self._exported_program = exported_program if exported_program else None

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Lower ``node`` into ``tosa_graph``; must be provided by subclasses."""
        message = "NodeVisitor must be extended."
        raise NotImplementedError(message)


# container for all node visitors, keyed by each visitor class's ``target``
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """Class decorator that records ``visitor`` under its ``target`` attribute.

    Args:
        visitor: A visitor class exposing a ``target`` attribute. A later
            registration with the same target replaces the earlier one.

    Returns:
        The same ``visitor`` class, unchanged.
    """
    _node_visitor_dict[visitor.target] = visitor
    # Hand the class back: a decorator's return value is what gets bound to
    # the decorated name, so returning nothing would rebind it to None.
    return visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
    """Instantiate every registered visitor with ``args``.

    Returns:
        A new dict mapping each registered target to a fresh visitor instance
        constructed as ``visitor(*args)``.
    """
    return {
        target: visitor_cls(*args)
        for target, visitor_cls in _node_visitor_dict.items()
    }
98 changes: 98 additions & 0 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2023 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
buildRescaleFromInt32,
buildRescaleToInt32,
)
from executorch.backends.arm.tosa_utils import broadcastShapes, getNodeArgs
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class AddVisitor(NodeVisitor):
    """Visitor that emits a TOSA ADD for ``aten.add.Tensor`` nodes."""

    target = "aten.add.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operators for ``node`` to ``tosa_graph``.

        Without quantization a single ADD of ``inputs[0]`` and ``inputs[1]``
        into ``output`` is emitted. For a quantized node both operands are
        rescaled to INT32, added, and the sum is rescaled into ``output``
        using the scale and zero point read from the node's first user.
        """
        if not is_quant_node:
            # FP32 Add lowering
            tosa_graph.addOperator(
                TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], [output.name], None
            )
            return

        # A single input node is used for both operands; otherwise exactly
        # two input nodes are expected.
        producers = node.all_input_nodes
        if len(producers) == 1:
            lhs_node = rhs_node = producers[0]
        else:
            lhs_node, rhs_node = producers

        # Scale factors and zero points of both operands
        lhs, lhs_scale, lhs_zp, _, _, _ = getNodeArgs(lhs_node)
        rhs, rhs_scale, rhs_zp, _, _, _ = getNodeArgs(rhs_node)

        # Both operands are rescaled relative to twice the larger scale.
        max_scale_2x = 2.0 * max(lhs_scale.number, rhs_scale.number)
        lhs_rescale = lhs_scale.number / max_scale_2x
        rhs_rescale = rhs_scale.number / max_scale_2x

        lhs_int32 = buildRescaleToInt32(tosa_graph, lhs, lhs_zp.number, lhs_rescale)
        rhs_int32 = buildRescaleToInt32(tosa_graph, rhs, rhs_zp.number, rhs_rescale)

        # The INT32 add, on the broadcast of both operand shapes
        result_shape = broadcastShapes(lhs.shape, rhs.shape)
        int32_sum = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)
        tosa_graph.addOperator(
            TosaOp.Op().ADD,
            [lhs_int32.name, rhs_int32.name],
            [int32_sum.name],
            None,
        )

        # Output quantization parameters come from the first user of this node.
        consumer = list(node.users)[0]
        _, out_scale, out_zp, _, _, _ = getNodeArgs(consumer)
        out_rescale = max_scale_2x / (out_scale.number)

        # Rescale back to INT8
        buildRescaleFromInt32(
            tosa_graph,
            int32_sum.name,
            output.name,
            out_zp.number,
            out_rescale,
        )
Loading