From 74adfc15cf6291d62428c144892b103c5778cee9 Mon Sep 17 00:00:00 2001
From: haowhsu-quic
Date: Mon, 8 Apr 2024 10:23:11 +0800
Subject: [PATCH 1/3] Qualcomm AI Engine Direct - support static llama2 with
 kv_cache

summary
- support static kv_cached llama2 model
- add qnn_llama_runner
- add e2e example script verified with story110M
---
 backends/qualcomm/builders/__init__.py        |   2 +
 backends/qualcomm/builders/node_visitor.py    |   8 +-
 backends/qualcomm/builders/op_embedding.py    |   2 +-
 backends/qualcomm/builders/op_skip_ops.py     |   4 +-
 backends/qualcomm/builders/op_split.py        |  85 ++++
 backends/qualcomm/builders/qnn_constants.py   |   7 +
 backends/qualcomm/partition/common_defs.py    |   3 +-
 .../qualcomm/partition/qnn_partitioner.py     |   3 +-
 .../passes/fuse_consecutive_transpose.py      |  69 ++++
 backends/qualcomm/qnn_preprocess.py           |   8 +-
 backends/qualcomm/quantizer/quantizer.py      | 226 +---------
 backends/qualcomm/quantizer/utils.py          | 251 ++++++++++--
 .../runtime/backends/QnnBackendCache.cpp      |   2 +
 backends/qualcomm/scripts/build.sh            |   1 +
 backends/qualcomm/tests/models.py             |   2 +-
 backends/qualcomm/utils/utils.py              |  41 ++
 .../models/llama2/tokenizer/bpe_tokenizer.cpp |   2 +-
 examples/qualcomm/CMakeLists.txt              |  43 +-
 .../executor_runner/qnn_llama_runner.cpp      | 152 +++++++
 examples/qualcomm/llama2/llama.py             | 275 +++++++++++++
 .../qualcomm/llama2/model/static_llama.py     | 334 +++++++++++++++
 examples/qualcomm/llama2/runner/runner.cpp    | 385 ++++++++++++++++++
 examples/qualcomm/llama2/runner/runner.h      | 102 +++++
 examples/qualcomm/scripts/utils.py            | 129 ++++--
 extension/module/module.cpp                   |   7 +
 extension/module/module.h                     |  10 +
 26 files changed, 1860 insertions(+), 293 deletions(-)
 create mode 100644 backends/qualcomm/builders/op_split.py
 create mode 100644 backends/qualcomm/passes/fuse_consecutive_transpose.py
 create mode 100644 examples/qualcomm/executor_runner/qnn_llama_runner.cpp
 create mode 100644 examples/qualcomm/llama2/llama.py
 create mode 100644 examples/qualcomm/llama2/model/static_llama.py
 create mode 100644 examples/qualcomm/llama2/runner/runner.cpp
 create mode 100644 examples/qualcomm/llama2/runner/runner.h

diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index b63a5583b10..349c4404991 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -41,6 +41,7 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_squeeze,
     op_sub,
     op_tanh,
@@ -85,6 +86,7 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_squeeze,
     op_sub,
     op_tanh,
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 3dae32f882e..56414119d6e 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -283,6 +283,7 @@ def define_tensor(
         nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
         is_input_tensor: bool,
         node_name: str = None,
+        wrapper_idx: int = 0,
         is_tensor: bool = True,
     ) -> PyQnnWrapper.TensorWrapper:
         """
@@ -299,8 +300,9 @@
         if node_name is None:
             node_name = node.name
 
-        if node_name in nodes_to_wrappers:
-            return nodes_to_wrappers[node_name]
+        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
+            return cached
+
         tensor_name = node.name
         if is_graph_output(node):
             tensor_name = "output_" + tensor_name
@@ -341,7 +343,7 @@ def define_tensor(
                 tensor.detach().numpy(),
                 True,
             )
-        nodes_to_wrappers[node_name] = tensor_wrapper
+        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
         return tensor_wrapper
 
     def define_node(
diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py
index 905578790c0..a5d6aae1702 100644
--- a/backends/qualcomm/builders/op_embedding.py
+++ b/backends/qualcomm/builders/op_embedding.py
@@ -34,7 +34,7 @@ def define_node(
             weight_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
             nodes_to_wrappers,
-            is_input_tensor=False,
+            is_input_tensor=True,
         )
 
         indices_node = node.args[1]
diff --git a/backends/qualcomm/builders/op_skip_ops.py b/backends/qualcomm/builders/op_skip_ops.py
index 9a1839f604e..837fb84d3ca 100644
--- a/backends/qualcomm/builders/op_skip_ops.py
+++ b/backends/qualcomm/builders/op_skip_ops.py
@@ -46,5 +46,7 @@ def define_node(
             raise AssertionError(
                 f"Invalid number of index for {node.name }: {len(node.args[1])}"
             )
-        nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
+        nodes_to_wrappers[node.name] = {
+            0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1])
+        }
         return
diff --git a/backends/qualcomm/builders/op_split.py b/backends/qualcomm/builders/op_split.py
new file mode 100644
index 00000000000..b964586c446
--- /dev/null
+++ b/backends/qualcomm/builders/op_split.py
@@ -0,0 +1,85 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Split(NodeVisitor):
+    target = ["aten.split_with_sizes.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        split_input_tensors = [input_tensor_wrapper]
+
+        axis = 0 if len(node.args) < 3 else cast(int, node.args[2])
+        if axis < 0:
+            axis = axis % len(input_tensor.shape)
+        if "axis_order" in node.meta:
+            axis = node.meta["axis_order"].index(axis)
+
+        # this is not the general case, only a quick workaround here
+        index = np.arange(1, input_tensor.shape[axis], dtype=np.uint32)
+        index_shape = [len(index)]
+
+        split_output_tensors = []
+        for i in range(input_tensor.shape[axis]):
+            output_tensor = self.get_tensor(node, node, i)
+            output_tensor_wrapper = self.define_tensor(
+                node,
+                output_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                nodes_to_wrappers,
+                is_input_tensor=False,
+                wrapper_idx=i,
+            )
+            split_output_tensors.append(output_tensor_wrapper)
+
+        split_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpSplit.op_name,
+        )
+        split_op.AddInputTensors(split_input_tensors)
+        split_op.AddOutputTensors(split_output_tensors)
+
+        split_op.AddScalarParam(
+            OpSplit.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {"data": np.uint32(axis)},
+        )
+        split_op.AddTensorParam(
+            OpSplit.param_split_index,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(index_shape),
+            index_shape,
+            index,
+            True,
+        )
+
+        return split_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 82c50046bee..da64594cfab 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -247,6 +247,13 @@ class OpSoftmax:
     param_beta: str = "beta"
 
 
+@dataclass(init=False, frozen=True)
+class OpSplit:
+    op_name: str = "Split"
+    param_axis: str = "axis"
+    param_split_index: str = "split_index"
+
+
 @dataclass(init=False, frozen=True)
 class OpSqueeze:
     op_name: str = "Squeeze"
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index b06a5766a63..280d2ac452e 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -11,8 +11,9 @@
 not_supported_operator = [
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
-    exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.index.Tensor,
+    exir_ops.edge.aten.index_put.default,
 ]
 
 allow_list_operator = [
diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py
index a704d3a6336..afb43778349 100644
--- a/backends/qualcomm/partition/qnn_partitioner.py
+++ b/backends/qualcomm/partition/qnn_partitioner.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+from collections import defaultdict
 from typing import Any, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
@@ -49,7 +50,7 @@ def __init__(
         )
 
         self.skip_node_id_set = skip_node_id_set
-        self.nodes_to_wrappers = {}
+        self.nodes_to_wrappers = defaultdict(dict)
         self.qnn_manager = PyQnnManager.QnnManager(
             generate_qnn_executorch_option(compiler_specs)
         )
diff --git a/backends/qualcomm/passes/fuse_consecutive_transpose.py b/backends/qualcomm/passes/fuse_consecutive_transpose.py
new file mode 100644
index 00000000000..aa5dc151a29
--- /dev/null
+++ b/backends/qualcomm/passes/fuse_consecutive_transpose.py
@@ -0,0 +1,69 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ + +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class FuseConsecutiveTranspose(ExportPass): + """ + This pass fuses consecutive transpose / permute into one to reduce runtime + overhead + """ + + def __init__(self): + super().__init__() + self.op_map = { + exir_ops.edge.aten.permute_copy.default, + } + self.visited = set() + self.nodes = [] + + def _traverse(self, node): + if node.op == "call_function" and node.target in self.op_map: + self.nodes.append(node) + self.visited.add(node) + if len(node.users) == 1: + self._traverse(list(node.users)[0]) + + def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + graph = graph_module.graph + for n in graph_module.graph.nodes: + if n in self.visited: + continue + if n.op == "call_function" and n.target in self.op_map: + self._traverse(n) + num_nodes = len(self.nodes) + if num_nodes > 1: + input_node, output_node = self.nodes[0].args[0], self.nodes[-1] + input_shape = input_node.meta["val"].shape + axis_order = torch.arange(len(input_shape)).tolist() + for node in self.nodes: + axis_order = [axis_order[i] for i in node.args[1]] + with graph.inserting_after(input_node): + permute_op = exir_ops.edge.aten.permute_copy.default + permute_node = graph.create_node( + "call_function", permute_op, (input_node, axis_order) + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, permute_node) + # copy metadata + + permute_node.meta = output_node.meta + # clear current stack + + self.nodes = [] + + def call(self, graph_module: torch.fx.GraphModule): + self._fuse(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 1e6275892f0..1fac8056e80 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -5,12 +5,16 @@ # LICENSE file in the root directory of this source tree. 
import logging +from collections import defaultdict from typing import final, List import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear +from executorch.backends.qualcomm.passes.fuse_consecutive_transpose import ( + FuseConsecutiveTranspose, +) from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform @@ -47,6 +51,8 @@ def preprocess( InsertRequantize(edge_program), InsertIOQDQ(edge_program), LayoutTransform(edge_program, insert_permute=True), + # please enable this when apply convert_linear_to_conv2d + # FuseConsecutiveTranspose(), ] ) @@ -54,7 +60,7 @@ def preprocess( assert pass_result is not None enable_tensor_dump = qnn_manager.IsTensorDump() - nodes_to_wrappers = {} + nodes_to_wrappers = defaultdict(dict) node_visitors = get_node_visitors( edge_program, enable_tensor_dump=enable_tensor_dump ) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 1414af171a4..79e3e292057 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from enum import IntEnum, unique -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Callable, Dict, Optional, Sequence, Set import torch from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid @@ -16,23 +16,18 @@ from executorch.backends.qualcomm.passes.remove_clone import RemoveClone from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer -from torch import Tensor from torch._ops import OpOverload -from torch.ao.quantization.observer import ( - HistogramObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, +from torch.ao.quantization.quantizer import Quantizer +from torch.fx import GraphModule + +from .utils import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + get_ptq_per_channel_weight_config, + OP_ANNOTATOR, + QuantizationConfig, ) -from torch.ao.quantization.quantizer import ( - DerivedQuantizationSpec, - QuantizationSpec, - Quantizer, -) - -from torch.fx import GraphModule, Node - -from .utils import OP_ANNOTATOR, QuantizationConfig __all__ = [ "QnnQuantizer", @@ -54,205 +49,6 @@ class QuantDtype(IntEnum): use_8a8w = 2 -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - 
input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=torch.iinfo(torch.uint8).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - 
quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_weight_config( - act_dtype=torch.uint8, weight_dtype=torch.int8 -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - class QnnQuantizer(Quantizer): SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys()) diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index ee6eb1608d1..bd28752617c 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -4,28 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch +from torch import Tensor from torch._ops import OpOverload from torch._subclasses import FakeTensor +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, +) + from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) - from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) from torch.fx import Node -QUANT_ANNOTATION_KEY = "quantization_annotation" -OP_ANNOTATOR: Dict[OpOverload, Callable] = {} - @dataclass(eq=True, frozen=True) class QuantizationConfig: @@ -35,6 +40,209 @@ class QuantizationConfig: bias: Optional[QuantizationSpec | Callable] +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. 
+def get_16a4w_qnn_ptq_config() -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_weight_config( + act_dtype=torch.uint8, weight_dtype=torch.int8 +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. 
Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +QUANT_ANNOTATION_KEY = "quantization_annotation" +OP_ANNOTATOR: Dict[OpOverload, Callable] = {} + + def register_annotator(ops: List[OpOverload]): def decorator(annotator: Callable): for op in ops: @@ -117,15 +325,12 @@ def annotate_single_in_single_out( assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation - node_tensor = node.meta.get("val") - if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32: - return - - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) + if _is_input_float_tensor(node): + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: @@ -133,7 +338,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None return input_act_qspec = quantization_config.input_activation - output_act_qspec = quantization_config.output_activation + output_act_qspec = ( + quantization_config.output_activation + if node.meta["val"].dtype == torch.float32 + else None + ) input_qspec_map = {} input_act0 = node.args[0] @@ -433,10 +642,8 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None if isinstance(input_act1, Node): # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. 
if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -464,10 +671,8 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: if isinstance(input_act1, Node): # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec diff --git a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp index 0c569ae5ab6..8c753011ba6 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp @@ -88,6 +88,8 @@ QnnBackendCache::QnnBackendCache( QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE."); return; } else { + // TODO: need fix on this since qnn context binary could somehow + // pass the check of flatbuffer verifier // check if context binary came from flatbuffer flatbuffers::FlatBufferBuilder builder; flatbuffers::Verifier verifier( diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c8379cf0b7a..b2c8e0d61ca 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -71,6 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_SDK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index edc7a469f7b..a1eff816277 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -29,7 +29,7 @@ def __init__(self): super().__init__() def forward(self, x): - return 10.0 + x + return 10 + x class Arange(torch.nn.Module): diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 0a9b7d064d1..4a38fd3a24a 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -61,6 +61,47 @@ def qnn_edge_config() -> exir.EdgeCompileConfig: return exir.EdgeCompileConfig(_check_ir_validity=False) +def convert_linear_to_conv2d(module: torch.nn.Module): + class Conv2D(torch.nn.Module): + def __init__(self, weight, bias=None): + super().__init__() + use_bias = bias is not None + self.conv = torch.nn.Conv2d( + in_channels=weight.shape[0], + out_channels=weight.shape[1], + kernel_size=(1, 1), + padding=0, + bias=use_bias, + ) + self.conv.weight = torch.nn.Parameter( + weight.reshape(*weight.shape, 1, 1) + ) + if use_bias: + self.conv.bias = torch.nn.Parameter(bias) + + def forward(self, x): + rank = x.dim() + x = x.unsqueeze(-1) 
if rank == 3 else x.reshape(1, *x.shape, 1) + x = torch.transpose(x, 1, 2) + res = self.conv(x) + res = torch.transpose(res, 1, 2) + res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) + return res + + def replace_linear(module: torch.nn.Module): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + print('replaced: ', attr_str) + setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias)) + + for _, sub_module in module.named_children(): + sub_module = replace_linear(sub_module) + return module + + return replace_linear(module) + + def canonicalize_program(prog: ExportedProgram): # check if user specifies to use multi_contexts # this is a generic approach in case there exists multiple backends diff --git a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp index ed7d34aca4d..a517d3157fb 100644 --- a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp @@ -328,7 +328,7 @@ BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) { } delete[] str_buffer; - return Result(tokens); + return Result>(tokens); } } // namespace executor diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index cff5db2a63d..5b8e66e40b2 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -56,6 +56,7 @@ get_filename_component(EXECUTORCH_SOURCE_DIR ABSOLUTE ) set(_qnn_executor_runner__srcs ${_executor_runner__srcs}) +set(_qnn_llama_runner__srcs ${_llama_runner__srcs}) # portable_ops_lib gen_selected_ops("" "" "ON") @@ -74,6 +75,7 @@ target_include_directories(full_portable_ops_lib ${_common_include_directories} ) +# prerpocess executor runner src files list( TRANSFORM _qnn_executor_runner__srcs @@ -92,8 +94,29 @@ list( ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_executor_runner.cpp ) -add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) +# preprocess llama runner src files +list( + TRANSFORM + _qnn_llama_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" +) +list( + FILTER + _qnn_llama_runner__srcs + EXCLUDE REGEX + ".*runner.*$" +) +list( + PREPEND + _qnn_llama_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_llama_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.h +) +# build executor runner +add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) target_include_directories(qnn_executor_runner PUBLIC ${_common_include_directories} @@ -109,3 +132,21 @@ target_compile_options(qnn_executor_runner PUBLIC ${_common_compile_options} ) + +# build llama runner +add_executable(qnn_llama_runner ${_qnn_llama_runner__srcs}) +target_include_directories(qnn_llama_runner + PUBLIC + ${_common_include_directories} +) +target_link_libraries(qnn_llama_runner + qnn_executorch_backend + full_portable_ops_lib + extension_data_loader + extension_module + gflags +) +target_compile_options(qnn_llama_runner + PUBLIC + ${_common_compile_options} +) diff --git a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp new file mode 100644 index 00000000000..0e959a1caab --- /dev/null +++ b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct + * and the portable kernels. + * + * User could specify arguments like desired input data, iterfations, etc. + * Currently we assume that the outputs are all fp32 tensors. + */ + +#include +#include +#include +#include + +#include + +#include + +DEFINE_string( + model_path, + "qnn_llama2.pte", + "Model serialized in flatbuffer format."); + +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); + +DEFINE_string(input_list_path, "input_list.txt", "Model input list path."); + +DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); + +DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); + +DEFINE_double( + temperature, + 0.8f, + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); + +DEFINE_int32( + seq_len, + 128, + "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); + +int main(int argc, char** argv) { + using namespace torch::executor; + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + const char* model_path = FLAGS_model_path.c_str(); + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); + const char* prompt = FLAGS_prompt.c_str(); + double temperature = FLAGS_temperature; + int32_t seq_len = FLAGS_seq_len; + + // create llama runner + Runner runner(model_path, tokenizer_path, temperature); + ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); + + // MethodMeta describes the memory requirements of the method. + Result method_meta = runner.method_meta(); + ET_CHECK_MSG( + method_meta.ok(), + "Failed to get method_meta 0x%x", + (unsigned int)method_meta.error()); + + // Fill in data for input + std::ifstream input_list(FLAGS_input_list_path); + ET_CHECK_MSG(input_list.is_open(), "Failed to open input_list.txt"); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + std::string file_path; + size_t inference_index = 0, num_inputs = method_meta->num_inputs(); + std::vector> inputs(num_inputs); + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + // inputs: [tokens, pos_ids, atten_mask, kv_mask, k_cache, v_cache] + // tokens are determined by command line arguments + // pos_ids are infered inside runner + std::vector managed_inputs; + for (int input_index = 2; input_index < num_inputs; ++input_index) { + Result tensor_meta = + method_meta->input_tensor_meta(input_index); + + std::ifstream fin(input_files[input_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. 
file bytes: %zu, tensor bytes: %zu", + input_index, + file_size, + tensor_meta->nbytes()); + + inputs[input_index].resize(tensor_meta->nbytes()); + if (input_index <= 2) { + fin.seekg(0, fin.beg); + fin.read(inputs[input_index].data(), file_size); + } + fin.close(); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_inputs.emplace_back(ManagedTensor( + inputs[input_index].data(), 128, sizes, tensor_meta->scalar_type())); + } + + // generate tokens + std::string inference_output; + runner.generate( + prompt, seq_len, managed_inputs, [&](const std::string& piece) { + inference_output += piece; + }); + + auto output_file_name = FLAGS_output_folder_path + "/output_" + + std::to_string(inference_index++) + "_0.raw"; + std::ofstream fout(output_file_name.c_str()); + fout << inference_output; + fout.close(); + } + + return 0; +} diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py new file mode 100644 index 00000000000..3a43793e68e --- /dev/null +++ b/examples/qualcomm/llama2/llama.py @@ -0,0 +1,275 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import codecs +import json +import os +import sys + +from functools import partial + +import torch + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d +from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs +from executorch.examples.qualcomm.scripts.utils import ( + build_executorch_binary, + make_output_dir, + setup_common_args_and_variables, + SimpleADB, +) + +from sentencepiece import SentencePieceProcessor + + +def create_device_inputs(example_inputs): + # TODO: support batch inputs if necessary + input_list = "" + inputs, flat_inputs = [], [] + for input in example_inputs: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + + for i, data in enumerate(flat_inputs): + input_list += f"input_0_{i}.raw " + inputs.append(data) + + input_list += "\n" + return tuple(inputs), input_list + + +def calibrate(example_inputs, module: torch.fx.GraphModule): + sp_model = SentencePieceProcessor(model_file="tokenizer.model") + _, _, kv_mask, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + user_prompts = ["Once"] + for prompt in user_prompts: + token_list += sp_model.encode(prompt) + + def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0 + probs_sort /= probs_sort.sum(dim=-1, keepdim=True) + next_token = torch.multinomial(probs_sort, num_samples=1) + return probs_indices.gather(dim=-1, index=next_token) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + logits, kv_mask, k_caches, v_caches = module( + torch.full((1, 1), token_list[pos]), + torch.full((1, 1), pos), + kv_mask, + *k_caches, + *v_caches, + ) + pos += 1 + if pos >= len(token_list): + probs = torch.softmax(logits[:, -1] / 0.8, dim=-1) + 
token_list.append(sample_top_p(probs, 0.9).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./llama2_qnn", + default="./llama2_qnn", + type=str, + ) + + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.", + default="16a4w", + ) + + parser.add_argument( + "--checkpoint", + help="Pass llama2 checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--params", + help="Pass llama2 params json file.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="Pass llama2 tokenizer binary.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama2 tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama2.", + required=True, + type=str, + ) + + parser.add_argument( + "--seq_len", + help="Ouput sequence length for llama2.", + default=128, + type=int, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama2.", + default=0.8, + type=float, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Pre-generated llama2.", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + with open(args.params) as f: + config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + config.max_batch_size = 1 + + state_dict = torch.load(args.checkpoint) + instance = LlamaModel(config) + instance.load_state_dict(state_dict["model"]) + inputs, input_list = create_device_inputs(instance.get_export_inputs()) + pte_filename = "llama2_qnn" + + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + ) + + if args.use_fp16: + quant_dtype = None + else: + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + if args.pre_gen_pte is None: + build_executorch_binary( + # try this if you want: convert_linear_to_conv2d(instance.eval()), + instance.eval(), + inputs, + args.model, + f"{args.artifact}/{pte_filename}", + partial(calibrate, instance.get_example_inputs()), + custom_annotations=(), + quant_dtype=quant_dtype, + per_channel_linear=True, + shared_buffer=args.shared_buffer, + metadata=instance.get_metadata(), + direct_io=True, + ) + + if args.compile_only: + sys.exit(0) + + # build custom commands for qnn_llama_runner + pte_path = ( + f"{args.artifact}/{pte_filename}.pte" + if args.pre_gen_pte is None + else args.pre_gen_pte + ) + workspace = f"/data/local/tmp/executorch/{pte_filename}" + runner_args = " ".join( + [ + f"--model_path {pte_filename}.pte", + "--output_folder_path outputs", + "--input_list_path input_list.txt", + f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", + f"--prompt {args.prompt}", + f"--seq_len {args.seq_len}", + f"--temperature {args.temperature}", + ] + ) + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "export ADSP_LIBRARY_PATH=. 
&&", + "export LD_LIBRARY_PATH=. &&", + f"./qnn_llama_runner {runner_args}", + ] + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + artifact_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/qnn_llama_runner", + ) + adb.push(inputs=[inputs], input_list=input_list, files=[args.tokenizer_bin]) + adb.execute(custom_runner_cmd=runner_cmd) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + with codecs.open( + os.path.join(output_data_folder, f), + "r", + encoding="utf-8", + errors="replace", + ) as fdata: + outputs.append(fdata.read()) + + adb.pull(output_path=args.artifact, callback=post_process) + + for idx, output in enumerate(outputs): + print(f"Results[{idx}]:\n{output}") diff --git a/examples/qualcomm/llama2/model/static_llama.py b/examples/qualcomm/llama2/model/static_llama.py new file mode 100644 index 00000000000..c4cc25607a1 --- /dev/null +++ b/examples/qualcomm/llama2/model/static_llama.py @@ -0,0 +1,334 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import torch +import torch.nn as nn + +from executorch.examples.models.llama2.llama_transformer import ( + apply_rotary_emb, + FeedForward, + ModelArgs, + precompute_freqs_cis, + RMSNorm, +) + + +class LlamaAttention(nn.Module): + def __init__(self, config: ModelArgs, split_kv_cache=False): + super().__init__() + self.dim = config.dim + self.n_heads = config.n_heads + self.head_dim = config.dim // config.n_heads + self.n_kv_heads = config.n_kv_heads + self.num_key_value_groups = config.n_heads // self.n_kv_heads + self.max_seq_len = config.max_seq_len + self.split_kv_cache = split_kv_cache + + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.attn_softmax = torch.nn.Softmax(dim=-1) + + scale = float(self.head_dim) ** -0.5 + scale_tensor = torch.tensor( + [scale], dtype=torch.float32, requires_grad=False + ).view(1, 1, 1) + self.register_buffer("scale_tensor", scale_tensor, False) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + kv_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seqlen, _ = hidden_states.shape + + q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) + q = q.view(bsz, seqlen, self.n_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + if self.split_kv_cache: + output_kh, output_vh, output_y = [], [], [] + for i, _ in enumerate(k_caches): + kh = k_caches[i] + k[:, :, i, :] * kv_mask + vh = v_caches[i] + v[:, :, i, :] * kv_mask + + attn = q[:, :, i, :] @ 
kh.permute(0, 2, 1) + attn = attn * self.scale_tensor + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + return y, output_kh, output_vh + else: + k = k_caches + k * kv_mask + v = v_caches + v * kv_mask + + attn = q.transpose(1, 2) @ k.permute(0, 2, 3, 1) + attn = attn * self.scale_tensor + atten_mask + attn = self.attn_softmax(attn) + y = attn @ v.transpose(1, 2) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + y = self.wo(y) + + return y, k, v + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, split_kv_cache=False): + super().__init__() + self.dim = config.dim + self.attention = LlamaAttention(config=config, split_kv_cache=split_kv_cache) + self.feed_forward = FeedForward(config) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + kv_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor]: + h, k_cache, v_cache = self.attention( + hidden_states=self.attention_norm(x), + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + kv_mask=kv_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + h = x + h + output = h + self.feed_forward(self.ffn_norm(h)) + return output, k_cache, v_cache + + +class LlamaModel(nn.Module): + def __init__(self, config: ModelArgs, split_kv_cache=False): + super().__init__() + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.max_batch_size = config.max_batch_size + self.max_seq_len = config.max_seq_len + self.n_heads = config.n_heads + self.n_kv_heads = config.n_kv_heads + self.n_layers = config.n_layers + self.vocab_size = config.vocab_size + self.split_kv_cache = split_kv_cache + + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, split_kv_cache) for _ in range(config.n_layers)] + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + freqs_cos, freqs_sin = precompute_freqs_cis( + config.dim // config.n_heads, + config.max_seq_len, + config.rope_freq_base, + ) + atten_mask = torch.triu( + torch.full( + (self.max_seq_len, self.max_seq_len), + -255.0, + ), + diagonal=1, + ) + self.register_buffer("atten_mask", atten_mask, persistent=False) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + if split_kv_cache: + self.register_buffer("mask", torch.ones(self.head_dim), persistent=False) + self.register_buffer("unmask", torch.zeros(self.head_dim), persistent=False) + else: + self.register_buffer("mask", torch.ones(self.dim), persistent=False) + self.register_buffer("unmask", torch.zeros(self.dim), persistent=False) + + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + kv_mask: torch.Tensor, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = self.freqs_cos[input_pos][0] + freqs_sin = self.freqs_sin[input_pos][0] + atten_mask = self.atten_mask[input_pos][0] + + hidden_states = self.tok_embeddings(tokens) + for 
ind, decoder_layer in enumerate(self.layers): + if self.split_kv_cache: + offset_k = ind * self.n_heads + offset_v = self.n_layers * self.n_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_heads] + v_caches = args[offset_v : offset_v + self.n_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + kv_mask=kv_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + else: + k_caches = args[ind] + v_caches = args[self.n_layers + ind] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + kv_mask=kv_mask.view( + self.max_seq_len, self.n_kv_heads, self.head_dim + ), + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.append(k) + output_v_cache.append(v) + + hidden_states = self.norm(hidden_states) + logits = self.output(hidden_states) + + # TODO: add op builder for kv mask update once HTP supports more ops + # this part is now expected to be fallback on cpu + # for simplicity, input_pos is assumed to never go over max_seq_len-1 + kv_mask[input_pos] = self.unmask + kv_mask[input_pos + 1] = self.mask + + return logits, kv_mask, output_k_cache, output_v_cache + + def get_example_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + k_cache, v_cache = [], [] + if self.split_kv_cache: + kv_mask = torch.zeros(self.max_seq_len, self.head_dim) + kv_mask[0] = torch.ones(self.head_dim) + for _ in range(self.n_layers): + for _ in range(self.n_heads): + k_cache += torch.zeros( + self.max_batch_size, + self.max_seq_len, + self.head_dim, + ) + v_cache += torch.zeros( + self.max_batch_size, + self.max_seq_len, + self.head_dim, + ) + else: + kv_mask = torch.zeros(self.max_seq_len, self.dim) + kv_mask[0] = torch.ones(self.dim) + for _ in range(self.n_layers): + k_cache += torch.zeros( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.head_dim, + ) + v_cache += torch.zeros( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.head_dim, + ) + return ( + tokens, + pos_ids, + kv_mask, + k_cache, + v_cache, + ) + + def get_export_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + # this is important for torch.export not to take it as dummy input + k_cache, v_cache = [], [] + if self.split_kv_cache: + kv_mask = torch.zeros(self.max_seq_len, self.head_dim) + kv_mask[0] = torch.ones(self.head_dim) + for _ in range(self.n_layers): + for _ in range(self.n_heads): + k_cache += torch.randn( + self.max_batch_size, + self.max_seq_len, + self.head_dim, + ) + v_cache += torch.randn( + self.max_batch_size, + self.max_seq_len, + self.head_dim, + ) + else: + kv_mask = torch.zeros(self.max_seq_len, self.dim) + kv_mask[0] = torch.ones(self.dim) + for _ in range(self.n_layers): + k_cache += torch.randn( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.head_dim, + ) + v_cache += torch.randn( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.head_dim, + ) + return ( + tokens, + pos_ids, + kv_mask, + k_cache, + v_cache, + ) + + def get_metadata(self): + # TODO: modify this when enabling LLAMA 7B + return { + "get_bos_id": 1, + "get_eos_id": 2, + "get_head_dim": self.dim // self.n_heads, + 
"get_max_batch_size": self.max_batch_size, + "get_max_seq_len": self.max_seq_len, + "get_n_bos": 1, + "get_n_eos": 1, + "get_n_kv_heads": self.n_heads, + "get_n_layers": self.n_layers, + "get_vocab_size": self.vocab_size, + } diff --git a/examples/qualcomm/llama2/runner/runner.cpp b/examples/qualcomm/llama2/runner/runner.cpp new file mode 100644 index 00000000000..881fe96ae4e --- /dev/null +++ b/examples/qualcomm/llama2/runner/runner.cpp @@ -0,0 +1,385 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace executor { + +namespace { +static constexpr auto kTopp = 0.9f; +void printReport(const Runner::Stats& stats); +std::string statsToJsonString(const Runner::Stats& stats); +} // namespace + +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + : module_(std::make_unique( + model_path, + Module::MlockConfig::UseMlockIgnoreErrors)), + tokenizer_path_(tokenizer_path), + temperature_(temperature) { + ET_LOG( + Info, + "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", + model_path.c_str(), + tokenizer_path.c_str()); +} + +bool Runner::is_loaded() const { + return module_->is_loaded() && tokenizer_ && sampler_; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + stats_.model_load_start_ms = util::time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); + + // Read out metadata from the model + ET_LOG(Info, "Reading metadata from model"); + const auto method_names = module_->method_names(); + ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); + model_methods_ = method_names.get(); + vocab_size_ = getMetadataHelper("get_vocab_size", 32000); + bos_id_ = getMetadataHelper("get_bos_id", 1); + eos_id_ = getMetadataHelper("get_eos_id", 2); + n_bos_ = getMetadataHelper("get_n_bos", 1); + n_eos_ = getMetadataHelper("get_n_eos", 1); + max_seq_len_ = getMetadataHelper("get_max_seq_len", 128); + + // Load tokenizer + tokenizer_ = std::make_unique(vocab_size_, bos_id_, eos_id_); + tokenizer_->load(tokenizer_path_); + if (tokenizer_->bos_tok() != bos_id_) { + ET_LOG( + Error, + "Tokenizer's BOS id %lu does not match model's BOS id %d, will override tokenizer's BOS.", + tokenizer_->bos_tok(), + bos_id_); + } + if (tokenizer_->eos_tok() != eos_id_) { + ET_LOG( + Error, + "Tokenizer's EOS id %lu does not match model's EOS id %d, will override tokenizer's EOS.", + tokenizer_->eos_tok(), + eos_id_); + } + // Create sampler + sampler_ = std::make_unique( + vocab_size_, + temperature_, + kTopp, + static_cast(std::time(nullptr))); + stats_.model_load_end_ms = util::time_in_ms(); + + return Error::Ok; +} + +template +T Runner::getMetadataHelper(std::string method_name, T default_val) { + T res = default_val; + if (model_methods_.count(method_name)) { + Result> outputs = module_->execute(method_name); + if (outputs.ok()) { + std::vector outs = outputs.get(); + if (outs.size() > 0) { + res = outs[0].to(); + } + } + } else { + ET_LOG( + Info, + "The model does not 
contain %s method, using default value %lld", + method_name.c_str(), + (long long)default_val); + } + ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); + return res; +} + +template +int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { + T* logits = logits_tensor.mutable_data_ptr(); + + // Since the logits are for all tokens, get the last token probabilities + T* logits_last = logits; + return sampler_->sample(logits_last); +} + +// Given an input token. Set up the inputs for the model and execute a single +// step. Returning the logits tensor. +Result Runner::run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + std::vector& input_tensors) { + token.mutable_data_ptr()[0] = input_token; + // inputs:[tokens, start_pos, atten_mask, kv_mask, k_cache, v_cache] + std::vector inputs = {token, start_pos}; + inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end()); + + Result> outputs_res = module_->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + + // Bump start_pos by 1 + start_pos.mutable_data_ptr()[0]++; + return outputs_res.get()[1].toTensor(); +} + +// TODO: add overloaded method for on-device tokenize +Error Runner::generate( + const std::string& prompt, + int32_t seq_len, + std::vector& managed_inputs, + std::function token_callback, + std::function stats_callback) { + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + ET_CHECK_MSG(is_loaded(), "Please invoke load method first"); + + // First token time only measures the time it takes to encode the prompt and + // return a response token. + stats_.inference_start_ms = util::time_in_ms(); + shouldStop_ = false; + + // Set the sequence length to the max seq length if not provided + seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? 
seq_len : max_seq_len_; + + Result> encode_res = + tokenizer_->encode(prompt, n_bos_, 0); + + ET_CHECK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + + ET_CHECK_MSG( + num_prompt_tokens < max_seq_len_, + "Max seq length exceeded - please increase max seq len value in static_llama.py"); + + ET_CHECK_MSG( + num_prompt_tokens < seq_len, + "Sequence length exceeded - please increase the seq_len value passed to generate()"); + + int32_t pos = 0, prev_token, cur_token = prompt_tokens[0]; + std::vector token_data = {1}; + std::vector token_shape = {1, 1}; + + std::vector start_pos_data = {0}; + std::vector start_pos_shape = {1, 1}; + + std::vector logits_data(vocab_size_); + std::vector logits_data_shape = {1, vocab_size_}; + + // initialize tensor wrappers + ManagedTensor managed_token( + token_data.data(), 128, token_shape, ScalarType::Int); + ManagedTensor managed_pos_id( + start_pos_data.data(), 128, start_pos_shape, ScalarType::Int); + ManagedTensor managed_logits( + logits_data.data(), 128, logits_data_shape, ScalarType::Float); + + Tensor logits = managed_logits.get_aliasing_tensor(); + Tensor token = managed_token.get_aliasing_tensor(); + Tensor start_pos = managed_pos_id.get_aliasing_tensor(); + + // TODO: investigate why kv_mask was duplicated in the output + // current output: [kv_mask, logits, k_cache, v_cache, kv_mask] + // change following indexes back when issue got resolved + std::vector inputs; + for (int i = 0; i < managed_inputs.size(); ++i) { + inputs.push_back(managed_inputs[i].get_aliasing_tensor()); + ET_CHECK_MSG( + module_->set_output_data_ptr(inputs.back(), i + 2) == Error::Ok, + "Failed to set output tensor"); + } + ET_CHECK_MSG( + module_->set_output_data_ptr(logits, 1) == Error::Ok, + "Failed to set output tensor - logits"); + + // Start consuming user's prompts and generating new tokens + std::string final_output; + while (pos < seq_len - 1) { + // Run the model + Result logits_res = + run_model_step(cur_token, token, start_pos, inputs); + + if (pos == num_prompt_tokens) { + stats_.first_token_ms = util::time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats_.prompt_eval_end_ms = util::time_in_ms(); + } + + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); + exec_aten::Tensor& logits_tensor = logits_res.get(); + prev_token = cur_token; + + long sample_start_time_ms = util::time_in_ms(); + cur_token = logitsToToken(logits_tensor); + stats_.aggregate_sampling_time_ms += + util::time_in_ms() - sample_start_time_ms; + + // advance the state machine + if (pos < num_prompt_tokens - 1) { + // prefill, force the next token to be the next prompt token + cur_token = prompt_tokens[pos + 1]; + } + pos++; + + // print the token as string, decode it with the Tokenizer object + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + + if (token_callback) { + token_callback(piece_res.get()); + } + + if (shouldStop_) { + break; + } + + // data-dependent terminating condition: we have n_eos_ number of EOS + if (pos >= num_prompt_tokens && cur_token == eos_id_) { + ET_LOG(Info, "Reached to the end of generation"); + break; + } + } + stats_.inference_end_ms = util::time_in_ms(); + + if (pos == seq_len) { + ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len); + } + + stats_.num_prompt_tokens = num_prompt_tokens; + stats_.num_generated_tokens 
= pos - num_prompt_tokens; + printReport(stats_); + if (stats_callback) { + stats_callback(stats_); + } + + return Error::Ok; +} + +namespace { +void printReport(const Runner::Stats& stats) { + printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); + + ET_LOG( + Info, + "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, + stats.num_prompt_tokens, + stats.num_generated_tokens); + + ET_LOG( + Info, + "\tModel Load Time:\t\t%f (seconds)", + ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + double inference_time_ms = + (double)(stats.inference_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, + + (stats.num_generated_tokens) / + (double)(stats.inference_end_ms - stats.inference_start_ms) * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + double prompt_eval_time = + (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + (stats.num_prompt_tokens) / prompt_eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + double eval_time = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + ET_LOG( + Info, + "\t\tGenerated %" PRIu64 + " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + stats.num_generated_tokens, + eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + stats.num_generated_tokens / eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + // Time to first token is measured from the start of inference, excluding + // model load time. + ET_LOG( + Info, + "\tTime to first generated token:\t%f (seconds)", + ((double)(stats.first_token_ms - stats.inference_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", + stats.num_prompt_tokens + stats.num_generated_tokens, + (double)stats.aggregate_sampling_time_ms / + stats.SCALING_FACTOR_UNITS_PER_SECOND); +} + +std::string statsToJsonString(const Runner::Stats& stats) { + std::stringstream ss; + ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," + << "\"generated_tokens\":" << stats.num_generated_tokens << "," + << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," + << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," + << "\"inference_start_ms\":" << stats.inference_start_ms << "," + << "\"inference_end_ms\":" << stats.inference_end_ms << "," + << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," + << "\"first_token_ms\":" << stats.first_token_ms << "," + << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms + << "," + << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" + << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; + return ss.str(); +} +} // namespace + +void Runner::stop() { + shouldStop_ = true; +} + +Result Runner::method_meta() { + return module_->method_meta("forward"); +} + +// explicit instantiation of template methods +template int64_t Runner::getMetadataHelper( + std::string method_name, + int64_t default_val); +template bool Runner::getMetadataHelper( + std::string method_name, + bool default_val); + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/llama2/runner/runner.h b/examples/qualcomm/llama2/runner/runner.h new file mode 100644 index 00000000000..120e8af5be7 --- /dev/null +++ 
b/examples/qualcomm/llama2/runner/runner.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace executor { + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature = 0.8f); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Start of model loading. + long model_load_start_ms; + // model_load_end_ms: End of model loading. + long model_load_end_ms; + // inference_start_ms: Immediately after the model is loaded (or we check + // for model load), measure the inference time. + long inference_start_ms; + // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right + // before the inference loop starts + long prompt_eval_end_ms; + // first_token: Timestamp when the first generated token is emitted + long first_token_ms; + // inference_end_ms: End of inference/generation. + long inference_end_ms; + // Keep a running total of the time spent in sampling. + long aggregate_sampling_time_ms; + // Token count from prompt + int64_t num_prompt_tokens; + // Token count from generated (total - prompt) + int64_t num_generated_tokens; + }; + + bool is_loaded() const; + Error load(); + Error generate( + const std::string& prompt, + int32_t seq_len, + std::vector& managed_inputs, + std::function token_callback = {}, + std::function stats_callback = {}); + void stop(); + Result method_meta(); + + private: + // metadata + template + T getMetadataHelper(std::string method_name, T default_val); + template + int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); + Result run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + std::vector& input_tensors); + // metadata + int32_t vocab_size_; + int64_t bos_id_; + int64_t eos_id_; + int32_t n_bos_; + int32_t n_eos_; + int32_t max_seq_len_; + std::unordered_set model_methods_; + std::unique_ptr module_; + std::string tokenizer_path_; + float temperature_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + bool shouldStop_{false}; + Stats stats_; +}; + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index f8c28371619..5ea7548f66c 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -10,7 +10,7 @@ import sys from pathlib import Path -from typing import Optional +from typing import Callable, List, Optional import numpy as np @@ -30,6 +30,7 @@ generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, ) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -48,6 +49,7 @@ def __init__( host_id=None, 
error_only=False, shared_buffer=False, + runner="examples/qualcomm/qnn_executor_runner", ): self.qnn_sdk = qnn_sdk self.artifact_path = artifact_path @@ -68,6 +70,7 @@ def __init__( self.soc_model = arch_table[soc_model] self.error_only = error_only self.shared_buffer = shared_buffer + self.runner = runner def _adb(self, cmd): if not self.host_id: @@ -80,7 +83,7 @@ def _adb(self, cmd): cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout ) - def push(self, inputs, input_list): + def push(self, inputs, input_list, files=None): self._adb(["shell", f"rm -rf {self.workspace}"]) self._adb(["shell", f"mkdir -p {self.workspace}"]) @@ -104,7 +107,7 @@ def push(self, inputs, input_list): ), f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so", f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so", - f"{self.artifact_path}/examples/qualcomm/qnn_executor_runner", + f"{self.artifact_path}/{self.runner}", f"{self.artifact_path}/backends/qualcomm/libqnn_executorch_backend.so", input_list_file, ]: @@ -112,31 +115,47 @@ def push(self, inputs, input_list): # input data for idx, data in enumerate(inputs): - for i, d in enumerate(data): + flat_inputs = [] + for input in data: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + for i, d in enumerate(flat_inputs): file_name = f"{self.working_dir}/input_{idx}_{i}.raw" d.detach().numpy().tofile(file_name) self._adb(["push", file_name, self.workspace]) - def execute(self): + # extra files + if files is not None: + for f in files: + self._adb(["push", f, self.workspace]) + + def execute(self, custom_runner_cmd=None): self._adb(["shell", f"mkdir -p {self.output_folder}"]) # run the delegation - qnn_executor_runner_args = " ".join( - [ - f"--model_path {os.path.basename(self.pte_path)}", - f"--output_folder_path {self.output_folder}", - f"--input_list_path {self.input_list_filename}", - f"--etdump_path {self.etdump_path}", - "--shared_buffer" if self.shared_buffer else "", - ] - ) - qnn_executor_runner_cmds = " ".join( - [ - f"cd {self.workspace} &&", - "export ADSP_LIBRARY_PATH=. &&", - "export LD_LIBRARY_PATH=. &&", - f"./qnn_executor_runner {qnn_executor_runner_args}", - ] - ) + if custom_runner_cmd is None: + qnn_executor_runner_args = " ".join( + [ + f"--model_path {os.path.basename(self.pte_path)}", + f"--output_folder_path {self.output_folder}", + f"--input_list_path {self.input_list_filename}", + f"--etdump_path {self.etdump_path}", + "--shared_buffer" if self.shared_buffer else "", + ] + ) + qnn_executor_runner_cmds = " ".join( + [ + f"cd {self.workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. 
&&", + f"./qnn_executor_runner {qnn_executor_runner_args}", + ] + ) + else: + qnn_executor_runner_cmds = custom_runner_cmd + self._adb(["shell", f"{qnn_executor_runner_cmds}"]) def pull(self, output_path, callback=None): @@ -156,16 +175,20 @@ def build_executorch_binary( inputs, # noqa: B006 soc_model, file_name, - dataset, + dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None], custom_annotations=(), skip_node_id_set=None, skip_node_op_set=None, quant_dtype: Optional[QuantDtype] = None, + per_channel_linear=False, # TODO: remove this once QNN fully supports linear + direct_io=False, # TODO: temporal workaround for llama shared_buffer=False, + metadata=None, ): - if quant_dtype: + if quant_dtype is not None: quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) + quantizer.set_per_channel_linear_quant(per_channel_linear) if quant_dtype == QuantDtype.use_8a8w: pass # default setting @@ -183,8 +206,11 @@ def build_executorch_binary( annotated_model = prepare_pt2e(captured_model, quantizer) print("Quantizing the model...") # calibration - for data in dataset: - annotated_model(*data) + if callable(dataset): + dataset(annotated_model) + else: + for data in dataset: + annotated_model(*data) quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs) @@ -208,29 +234,44 @@ def build_executorch_binary( debug=False, saver=False, shared_buffer=shared_buffer, + profile=False, ), skip_node_id_set, skip_node_op_set, ) - edge_prog.exported_program = to_backend(edge_prog.exported_program, qnn_partitioner) - edge_prog.exported_program.graph_module.graph.print_tabular() - exec_prog = edge_prog.to_executorch( - config=ExecutorchBackendConfig( - extract_constant_segment=False, - # For shared buffer, user must pass the memory address - # which is allocated by RPC memory to executor runner. - # Therefore, won't want to pre-allocate - # by memory manager in runtime. - memory_planning_pass=MemoryPlanningPass( - memory_planning_algo="greedy", - alloc_graph_input=not shared_buffer, - alloc_graph_output=not shared_buffer, - ), - extract_delegate_segments=True, - ) + + executorch_config = ExecutorchBackendConfig( + extract_constant_segment=False, + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=not shared_buffer and not direct_io, + alloc_graph_output=not shared_buffer and not direct_io, + ), + extract_delegate_segments=True, ) - with open(f"{file_name}.pte", "wb") as file: - file.write(exec_prog.buffer) + + if metadata is None: + edge_prog.exported_program = to_backend( + edge_prog.exported_program, qnn_partitioner + ) + edge_prog.exported_program.graph_module.graph.print_tabular() + exec_prog = edge_prog.to_executorch(config=executorch_config) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog.buffer) + else: + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=metadata, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog_mgr.buffer) def make_output_dir(path: str): diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 83ded144469..015c6cfc68e 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -156,4 +156,11 @@ Result> Module::execute( return outputs; } +Error Module::set_output_data_ptr(Tensor& output_tensor, size_t output_index) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method("forward")); + auto& method = methods_.at("forward").method; + return method->set_output_data_ptr( + output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); +} + } // namespace torch::executor diff --git a/extension/module/module.h b/extension/module/module.h index fb70cb08417..83fff368db8 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -191,6 +191,16 @@ class Module final { return event_tracer_.get(); } + /** + * Set output data pointer for forward method. + * + * @param[in] output_tensor A Tensor for the output of 'forward' method. + * @param[in] output_index Index of the output in 'forward' method. + * + * @returns An Error to indicate success or failure of the loading process. 
+ */ + Error set_output_data_ptr(Tensor& output_tensor, size_t output_index); + private: struct MethodHolder { std::vector> planned_buffers; From 3b6af6422263722a61aa770b474968fd16c9aea2 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Mon, 22 Apr 2024 11:05:53 +0800 Subject: [PATCH 2/3] resolve uint16 type and reorder input in runtime --- backends/qualcomm/builders/node_visitor.py | 24 +++++++------------ backends/qualcomm/builders/op_split.py | 2 +- backends/qualcomm/builders/qnn_constants.py | 1 - backends/qualcomm/quantizer/utils.py | 8 +++---- .../qualcomm/runtime/QnnExecuTorchBackend.cpp | 11 ++++++++- .../runtime/backends/QnnBackendCache.cpp | 5 ++-- .../executor_runner/qnn_llama_runner.cpp | 6 +---- examples/qualcomm/llama2/llama.py | 2 ++ examples/qualcomm/llama2/runner/runner.cpp | 3 ++- examples/qualcomm/scripts/utils.py | 6 +++-- 10 files changed, 36 insertions(+), 32 deletions(-) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 56414119d6e..26520afab68 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -14,8 +14,6 @@ from executorch.exir.dialects._ops import ops as exir_ops -from .qnn_constants import QNN_uint16 - from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter @@ -26,7 +24,7 @@ # Note that there is no int64 tensor data type in Qnn. torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED, torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, - QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, @@ -35,7 +33,7 @@ torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64, torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8, - QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16, + torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16, float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, } @@ -169,7 +167,7 @@ def get_quant_encoding_conf( return self.make_qnn_per_tensor_config(quant_attrs) def get_quant_tensor_value( - self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth + self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict ) -> torch.Tensor: if quant_attrs["encoding"] in PER_TENSOR_ENCODING: scale = quant_attrs["scale"] @@ -178,16 +176,11 @@ def get_quant_tensor_value( scale = quant_attrs["scales"] zero_point = quant_attrs["zero_points"] - # To bypass torch.uint16 quantization is not supported - dtype = ( - torch.int32 - if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16 - else quant_attrs["dtype"] - ) + dtype = quant_configs["dtype"] tensor = tensor.div(scale).add(zero_point).round().to(dtype) # Make the backends access data correctly - if bitwidth == 4: + if quant_configs.get("bitwidth") == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) tensor = torch.bitwise_and(mask, tensor) return tensor @@ -236,7 +229,7 @@ def get_data_type( <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min ): if unsigned: - quant_config["dtype"] = QNN_uint16 + quant_config["dtype"] = torch.uint16 else: quant_config["dtype"] = torch.int16 return QNN_QUANT_TYPE_MAP[quant_config["dtype"]] @@ -304,6 +297,8 @@ def define_tensor( return cached tensor_name = node.name + if is_graph_input(node, 
self.edge_program): + tensor_name = "QnnInput_"+str(self.external_ids[node])+"_"+ tensor_name if is_graph_output(node): tensor_name = "output_" + tensor_name dims = [1] if len(tensor.size()) == 0 else tensor.size() @@ -329,8 +324,7 @@ def define_tensor( tensor = self.get_quant_tensor_value( tensor, node.meta["quant_attrs"], - dtype, - quant_configs.get("bitwidth"), + quant_configs, ) tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, diff --git a/backends/qualcomm/builders/op_split.py b/backends/qualcomm/builders/op_split.py index b964586c446..00bfb3e556d 100644 --- a/backends/qualcomm/builders/op_split.py +++ b/backends/qualcomm/builders/op_split.py @@ -15,7 +15,7 @@ @register_node_visitor -class Softmax(NodeVisitor): +class Split(NodeVisitor): target = ["aten.split_with_sizes.default"] def __init__(self, *args) -> None: diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index da64594cfab..092e3b208f0 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -8,7 +8,6 @@ from enum import IntEnum, unique QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw" -QNN_uint16 = "uint16" # Below constants should be same as those in QNN headers. # Maybe someday we should expose these constants by pybind diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index bd28752617c..f9e0748b967 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -113,14 +113,14 @@ def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: # 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config() -> QuantizationConfig: +def get_16a4w_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-20} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, quant_max=torch.iinfo(torch.uint16).max, qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( @@ -150,14 +150,14 @@ def get_16a4w_qnn_ptq_config() -> QuantizationConfig: return quantization_config -def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: +def get_default_16bit_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-20} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, quant_max=torch.iinfo(torch.uint16).max, qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 77449703c5f..8c787cf7981 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -11,7 +11,7 @@ #include #include #include - +#include #include namespace torch { namespace executor { @@ -20,6 +20,12 @@ using namespace qnn; using namespace qnn_delegate; constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec"; +bool CompareQnnInput(const std::shared_ptr& a, const std::shared_ptr& b) { + int numA = 
std::stoi(a->GetName().substr(a->GetName().find('_') + 1)); + int numB = std::stoi(b->GetName().substr(b->GetName().find('_') + 1)); + return numA < numB; +} + Result QnnExecuTorchBackend::init( BackendInitContext& context, FreeableBuffer* processed, @@ -187,6 +193,9 @@ Error QnnExecuTorchBackend::execute( qnn_manager->GetGraphOutputs(); std::vector input_tensor_structs; std::vector output_tensor_structs; + // Using the order of the nodes as external_id in AOT + // to extract the right arg from *args at runtime + std::sort(input_tensors.begin(), input_tensors.end(), CompareQnnInput); input_tensor_structs.reserve(input_tensors.size()); for (int i = 0; i < input_tensors.size(); ++i) { diff --git a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp index 8c753011ba6..b8ac289133b 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp @@ -87,7 +87,8 @@ QnnBackendCache::QnnBackendCache( state_ = SERIALIZE; QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE."); return; - } else { + } + /*else { // TODO: need fix on this since qnn context binary could somehow // pass the check of flatbuffer verifier // check if context binary came from flatbuffer @@ -100,7 +101,7 @@ QnnBackendCache::QnnBackendCache( state_ = ONLINE_PREPARE; return; } - } + }*/ if (qnn_sys_impl_.Load() != Error::Ok) { QNN_EXECUTORCH_LOG_ERROR( diff --git a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp index 0e959a1caab..9733baf5c12 100644 --- a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp @@ -100,7 +100,7 @@ int main(int argc, char** argv) { if (input_files.size() == 0) { break; } - // inputs: [tokens, pos_ids, atten_mask, kv_mask, k_cache, v_cache] + // inputs: [tokens, pos_ids, kv_mask, *k_cache, *v_cache] // tokens are determined by command line arguments // pos_ids are infered inside runner std::vector managed_inputs; @@ -120,10 +120,6 @@ int main(int argc, char** argv) { tensor_meta->nbytes()); inputs[input_index].resize(tensor_meta->nbytes()); - if (input_index <= 2) { - fin.seekg(0, fin.beg); - fin.read(inputs[input_index].data(), file_size); - } fin.close(); auto tensor_shape = tensor_meta->sizes(); diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py index 3a43793e68e..49170184742 100644 --- a/examples/qualcomm/llama2/llama.py +++ b/examples/qualcomm/llama2/llama.py @@ -12,6 +12,7 @@ from functools import partial import torch +from torch.ao.quantization.observer import MinMaxObserver from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d @@ -206,6 +207,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: shared_buffer=args.shared_buffer, metadata=instance.get_metadata(), direct_io=True, + act_observer=MinMaxObserver ) if args.compile_only: diff --git a/examples/qualcomm/llama2/runner/runner.cpp b/examples/qualcomm/llama2/runner/runner.cpp index 881fe96ae4e..0b4fa9e71af 100644 --- a/examples/qualcomm/llama2/runner/runner.cpp +++ b/examples/qualcomm/llama2/runner/runner.cpp @@ -139,7 +139,8 @@ Result Runner::run_model_step( Tensor& start_pos, std::vector& input_tensors) { token.mutable_data_ptr()[0] = input_token; - // inputs:[tokens, start_pos, atten_mask, kv_mask, k_cache, v_cache] + // inputs:[tokens, 
start_pos, kv_mask, k_cache, v_cache] + // input_tensors:[kv_mask, k_cache, v_cache] std::vector inputs = {token, start_pos}; inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end()); diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index 5ea7548f66c..1064d9ff3a2 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -15,6 +15,7 @@ import numpy as np import torch +from torch.ao.quantization.observer import MovingAverageMinMaxObserver from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a4w_qnn_ptq_config, @@ -184,6 +185,7 @@ def build_executorch_binary( direct_io=False, # TODO: temporal workaround for llama shared_buffer=False, metadata=None, + act_observer=MovingAverageMinMaxObserver ): if quant_dtype is not None: quantizer = QnnQuantizer() @@ -194,10 +196,10 @@ def build_executorch_binary( pass # default setting elif quant_dtype == QuantDtype.use_16a16w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config(act_observer=act_observer)) elif quant_dtype == QuantDtype.use_16a4w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config(act_observer=act_observer)) quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") else: raise AssertionError(f"No support for QuantDtype {quant_dtype}.") From 9646142b4fdfde96befc1e025c8fd13d31e83864 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Thu, 16 May 2024 23:09:07 -0700 Subject: [PATCH 3/3] Split llama and modify architecture for performance Note that this branch is for an example. llama2 cannot work by this branch. What we did to optimize performance on HTP is listed: 1. One multihead attentions is transformed to multiple single head. 2. KV-cache is changed to graph I/O. The update is performed in qnn_llama_runner.cpp on CPU. 3. llama2 is partitioned to 6 pte files in examples/qualcomm/llama2/composite_llama.py 4. Embedding is quantized. This might need further investigation, e.g., can we move it out of the model on CPU..etc 5. Support u16 and u8 mixed-precision quantization. 6. KV-cache is left as quantized format in graph I/O. 7. RMSNorm is tweaked a bit to reduce the quantization sensitivity. 8. HTP Spill-Fill buffer feature is used among pte files. 9. Convert all Linear layers to Conv2d. 10 Properly set quant_min and quant_max in Observers to offset=128 in symmetrical quantization. 
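For item 9 above, a minimal standalone sketch of the idea (not part of this patch; the actual helper is convert_linear_to_conv2d in backends/qualcomm/utils/utils.py, and the shapes below are illustrative): a Linear(in, out) applied to [batch, seq, in] activations is numerically the same as a 1x1 Conv2d(in, out) applied to that data laid out as [batch, in, 1, seq], which is the form the patch lowers to for HTP performance.

import torch
import torch.nn as nn

# Sketch only: show that a Linear layer can be re-expressed as a 1x1 Conv2d.
# This mirrors the idea behind convert_linear_to_conv2d, not its implementation.
linear = nn.Linear(64, 128)
conv = nn.Conv2d(64, 128, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(128, 64, 1, 1))  # [out, in, 1, 1]
    conv.bias.copy_(linear.bias)

x = torch.randn(1, 16, 64)                                # [batch, seq, dim]
y_ref = linear(x)
y_conv = conv(x.permute(0, 2, 1).unsqueeze(2))            # NCHW with H=1, W=seq
y_conv = y_conv.squeeze(2).permute(0, 2, 1)               # back to [batch, seq, out]
assert torch.allclose(y_ref, y_conv, atol=1e-5)
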
--- backends/qualcomm/builders/__init__.py | 8 +- backends/qualcomm/builders/node_visitor.py | 33 +- backends/qualcomm/builders/op_linear.py | 2 +- backends/qualcomm/builders/op_log_softmax.py | 1 - backends/qualcomm/builders/op_slice_copy.py | 4 +- .../builders/{op_cast.py => op_sqrt.py} | 24 +- backends/qualcomm/builders/op_sum_int_list.py | 80 ++ backends/qualcomm/builders/op_to.py | 104 +++ backends/qualcomm/builders/qnn_constants.py | 17 + backends/qualcomm/builders/utils.py | 19 + backends/qualcomm/partition/common_defs.py | 1 + .../qualcomm/partition/qnn_partitioner.py | 4 + backends/qualcomm/passes/build_quant_io.py | 52 ++ .../passes/fuse_consecutive_transpose.py | 43 +- backends/qualcomm/passes/insert_io_qdq.py | 37 +- backends/qualcomm/passes/insert_requantize.py | 42 +- backends/qualcomm/passes/layout_transform.py | 8 +- backends/qualcomm/passes/utils.py | 2 + backends/qualcomm/qnn_preprocess.py | 4 +- backends/qualcomm/quantizer/quantizer.py | 2 + backends/qualcomm/quantizer/utils.py | 80 +- backends/qualcomm/setup.md | 1 - backends/qualcomm/tests/models.py | 105 ++- backends/qualcomm/tests/test_qnn_delegate.py | 80 +- backends/qualcomm/utils/utils.py | 14 +- .../executor_runner/qnn_llama_runner.cpp | 125 ++- examples/qualcomm/llama2/composite_llama.py | 873 ++++++++++++++++++ examples/qualcomm/llama2/llama.py | 33 +- .../qualcomm/llama2/model/static_llama.py | 307 +++--- examples/qualcomm/llama2/runner/runner.cpp | 279 ++++-- examples/qualcomm/llama2/runner/runner.h | 22 +- examples/qualcomm/scripts/utils.py | 28 +- 32 files changed, 2015 insertions(+), 419 deletions(-) rename backends/qualcomm/builders/{op_cast.py => op_sqrt.py} (70%) create mode 100644 backends/qualcomm/builders/op_sum_int_list.py create mode 100644 backends/qualcomm/builders/op_to.py create mode 100644 backends/qualcomm/passes/build_quant_io.py create mode 100644 examples/qualcomm/llama2/composite_llama.py diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 349c4404991..5bd02ce53af 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -10,7 +10,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -42,9 +41,12 @@ op_slice_copy, op_softmax, op_split, + op_sqrt, op_squeeze, op_sub, + op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, @@ -56,7 +58,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -88,8 +89,11 @@ op_softmax, op_split, op_squeeze, + op_sqrt, op_sub, + op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 26520afab68..060dd77fa66 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -14,7 +14,13 @@ from executorch.exir.dialects._ops import ops as exir_ops -from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter +from .utils import ( + deduce_dtype, + get_parameter, + is_graph_input, + is_graph_output, + is_parameter, +) QNN_QUANT_TYPE_MAP = { @@ -27,6 +33,7 @@ torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { + torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, torch.int16: 
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, @@ -214,24 +221,9 @@ def get_data_type( self, tensor: torch.Tensor, quant_config: Dict, - is_tensor: bool, ) -> PyQnnWrapper.Qnn_TensorType_t: - if quant_config and is_tensor: - quant_range = quant_config["quant_max"] - quant_config["quant_min"] - unsigned = quant_config["quant_min"] >= 0 - if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: - if unsigned: - quant_config["dtype"] = torch.uint8 - else: - quant_config["dtype"] = torch.int8 - elif ( - quant_range - <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min - ): - if unsigned: - quant_config["dtype"] = torch.uint16 - else: - quant_config["dtype"] = torch.int16 + if quant_config: + quant_config["dtype"] = deduce_dtype(tensor, quant_config) return QNN_QUANT_TYPE_MAP[quant_config["dtype"]] else: return QNN_TENSOR_TYPE_MAP[tensor.dtype] @@ -277,7 +269,6 @@ def define_tensor( is_input_tensor: bool, node_name: str = None, wrapper_idx: int = 0, - is_tensor: bool = True, ) -> PyQnnWrapper.TensorWrapper: """ Covert torch.Tensor to TensorWrapper @@ -298,7 +289,7 @@ def define_tensor( tensor_name = node.name if is_graph_input(node, self.edge_program): - tensor_name = "QnnInput_"+str(self.external_ids[node])+"_"+ tensor_name + tensor_name = "QnnInput_" + str(self.external_ids[node]) + "_" + tensor_name if is_graph_output(node): tensor_name = "output_" + tensor_name dims = [1] if len(tensor.size()) == 0 else tensor.size() @@ -306,7 +297,7 @@ def define_tensor( quant_encoding, quant_configs = self.get_quant_encoding_conf( node, is_input_tensor ) - dtype = self.get_data_type(tensor, quant_configs, is_tensor) + dtype = self.get_data_type(tensor, quant_configs) if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 78d1e6244e9..9a593528219 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -62,7 +62,7 @@ def define_node( bias_node = node.args[2] # TODO remove this when qnn sdk support - if "scales" in bias_node.meta.get("quant_attrs"): + if "scales" in bias_node.meta.get("quant_attrs", {}): print( f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet." 
) diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index c159b9bf00e..002dd5bc9b2 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -72,5 +72,4 @@ def define_node( PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {"data": np.uint32(dim)}, ) - # pdb.set_trace() return log_softmax_op diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 7972fb3dd92..3a294e35486 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -61,7 +61,9 @@ def define_node( ranges = [] for i in range(input_tensor_rank): if i == dim: - ranges.extend([start, end, 1]) + # find step + step = node.args[4] if len(node.args) > 4 else 1 + ranges.extend([start, end, step]) else: ranges.extend([0, input_tensor.shape[i], 1]) diff --git a/backends/qualcomm/builders/op_cast.py b/backends/qualcomm/builders/op_sqrt.py similarity index 70% rename from backends/qualcomm/builders/op_cast.py rename to backends/qualcomm/builders/op_sqrt.py index d3096ee27cf..7847d00e8b8 100644 --- a/backends/qualcomm/builders/op_cast.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -10,12 +10,12 @@ import torch from .node_visitor import NodeVisitor, register_node_visitor -from .qnn_constants import OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW +from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor -class Cast(NodeVisitor): - target = ["aten._to_copy.default"] +class SQRT(NodeVisitor): + target = ["aten.sqrt.default"] def __init__(self, *args) -> None: super().__init__(*args) @@ -25,6 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: + # tensor input input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) @@ -35,23 +36,24 @@ def define_node( nodes_to_wrappers, is_input_tensor=True, ) + sqrt_input_tensors = [input_tensor_wrapper] - output_tensor = self.get_tensor(node, node) - + out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, - output_tensor, + out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, is_input_tensor=False, ) + sqrt_output_tensors = [output_tensor_wrapper] - cast_op = PyQnnWrapper.PyQnnOpWrapper( + sqrt_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpCast.op_name, + OpSqrt.op_name, ) - cast_op.AddInputTensors([input_tensor_wrapper]) - cast_op.AddOutputTensors([output_tensor_wrapper]) + sqrt_op.AddInputTensors(sqrt_input_tensors) + sqrt_op.AddOutputTensors(sqrt_output_tensors) - return cast_op + return sqrt_op diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py new file mode 100644 index 00000000000..26cc262462e --- /dev/null +++ b/backends/qualcomm/builders/op_sum_int_list.py @@ -0,0 +1,80 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
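A quick aside on the op_slice_copy change above (this snippet is not part of the patch and the example values are made up): the (start, end, step) triples it emits per dimension line up with ordinary Python slicing, with a full-range triple on every dimension other than the sliced one.

import torch

# Sketch of the per-dimension (start, end, step) ranges built by op_slice_copy.
x = torch.arange(24).reshape(2, 12)
dim, start, end, step = 1, 2, 11, 3        # assumed example values
ranges = []
for i in range(x.dim()):
    ranges.append((start, end, step) if i == dim else (0, x.shape[i], 1))
out = x[tuple(slice(s, e, st) for (s, e, st) in ranges)]
assert torch.equal(out, x[:, start:end:step])
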
+from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpReduceSum, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Sum(NodeVisitor): + target = ["aten.sum.dim_IntList"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + sum_input_tensors = [input_tensor_wrapper] + + # sum dims + sum_dims = cast(List[int], node.args[1]) + sum_dims = [sum_dim % len(input_node.meta["val"].shape) for sum_dim in sum_dims] + if "axis_order" in node.meta: + sum_dims = [node.meta["axis_order"].index(sum_dim) for sum_dim in sum_dims] + sum_dims_shape = [len(sum_dims)] + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + sum_output_tensors = [output_tensor_wrapper] + sum_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReduceSum.op_name, + ) + sum_op.AddInputTensors(sum_input_tensors) + sum_op.AddOutputTensors(sum_output_tensors) + sum_op.AddTensorParam( + OpReduceSum.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(sum_dims_shape), + sum_dims_shape, + np.array(sum_dims, dtype=np.uint32), + True, + ) + + if len(node.args) > 2: + keep_dims = cast(bool, node.args[2]) + sum_op.AddScalarParam( + OpReduceSum.param_keep_dims, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {"data": keep_dims}, + ) + return sum_op diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py new file mode 100644 index 00000000000..d2762eb8f6b --- /dev/null +++ b/backends/qualcomm/builders/op_to.py @@ -0,0 +1,104 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
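The dim remapping in the ReduceSum builder above is a recurring idiom in these builders: assuming node.meta["axis_order"] is the permutation the layout-transform pass applied to the original dims (transformed = original.permute(axis_order)), .index() recovers where an original dimension ended up. A standalone sketch (not part of the patch; the NCHW-to-NHWC permutation and shapes are assumed for illustration):

import torch

# Sketch of the axis_order remapping idiom used by the QNN op builders.
axis_order = (0, 2, 3, 1)                  # assumed: NCHW graph laid out as NHWC
x = torch.randn(2, 3, 4, 5)                # original NCHW tensor
x_t = x.permute(axis_order)                # tensor as seen after layout transform
dim = 1                                    # reduce over channels in the original layout
remapped = axis_order.index(dim)           # -> 3, channels are last in NHWC
assert torch.allclose(x.sum(dim=dim), x_t.sum(dim=remapped))
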
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class To(NodeVisitor): + target = ["aten._to_copy.default"] + sufixed_8_offset_diff = 128 + sufixed_16_offset_diff = 32768 + epsilon = 1e-6 + sufixed_8 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, + } + sufixed_16 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + } + + def __init__(self, *args) -> None: + super().__init__(*args) + + def is_cast_node(self, node): + input_node = node.args[0] + + # Not a case which has two quant node, no need to consider the convert op + if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]): + return True + + input_tensor = self.get_tensor(input_node, node) + _, inp_qconfs = self.get_quant_encoding_conf(input_node, False) + inp_dtype = self.get_data_type(input_tensor, inp_qconfs) + + output_tensor = self.get_tensor(node, node) + _, out_qconfs = self.get_quant_encoding_conf(node, False) + out_dtype = self.get_data_type(output_tensor, out_qconfs) + is_qparam_castalbe = ( + lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon + and abs(o1 - o2) == diff + ) + + if {inp_dtype, out_dtype} == self.sufixed_8: + return is_qparam_castalbe( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_8_offset_diff, + ) + elif {inp_dtype, out_dtype} == self.sufixed_16: + return is_qparam_castalbe( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_16_offset_diff, + ) + return False + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + output_tensor = self.get_tensor(node, node) + + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + qnn_op = OpCast if self.is_cast_node(node) else OpConvert + op = PyQnnWrapper.PyQnnOpWrapper( + node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name + ) + op.AddInputTensors([input_tensor_wrapper]) + op.AddOutputTensors([output_tensor_wrapper]) + + return op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 092e3b208f0..c776fe5a346 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -39,6 +39,11 @@ class OpConv2d: param_dilation: str = "dilation" +@dataclass(init=False, frozen=True) +class OpConvert: + op_name: str = "Convert" + + @dataclass(init=False, frozen=True) class OpDepthToSpace: op_name: str = "DepthToSpace" @@ -105,6 +110,13 @@ class OpExpandDims: param_axis: str = "axis" +@dataclass(init=False, frozen=True) +class OpReduceSum: + op_name: str = "ReduceSum" + param_axes: str = "axes" + param_keep_dims: str = "keep_dims" + + @dataclass(init=False, frozen=True) 
class OpFullyConnected: op_name: str = "FullyConnected" @@ -122,6 +134,11 @@ class OpGelu: op_name: str = "Gelu" +@dataclass(init=False, frozen=True) +class OpSqrt: + op_name: str = "ElementWiseSquareRoot" + + @dataclass(init=False, frozen=True) class OpHardSwish: op_name: str = "HardSwish" diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index 92c129f342f..217c840553c 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + import torch from torch._export.utils import get_buffer, get_param, is_buffer, is_param @@ -97,3 +99,20 @@ def is_constant( return tensor.meta["val"].constant is not None return False + + +def deduce_dtype( + tensor: torch.Tensor, quant_infos: Optional[Dict] = None +) -> torch.dtype: + if quant_infos: + quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] + unsigned = quant_infos["quant_min"] >= 0 + if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min + 1: + return torch.uint8 if unsigned else torch.int8 + + elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: + return torch.uint16 if unsigned else torch.int16 + + return quant_infos["dtype"] + + return tensor.dtype diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 280d2ac452e..61935cf3536 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -12,6 +12,7 @@ exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.index_put.default, ] diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index afb43778349..0c5b25284eb 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -96,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}") return supported + def __del__(self): + self.qnn_manager.Destroy() + class QnnPartitioner(Partitioner): def __init__( @@ -145,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu # pop certain keys in meta for not affecting the passes in compilation # TODO: need to put property name in common definitions node.meta.pop("axis_order", "") + del self.op_support_checker return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) diff --git a/backends/qualcomm/passes/build_quant_io.py b/backends/qualcomm/passes/build_quant_io.py new file mode 100644 index 00000000000..44b66592f3c --- /dev/null +++ b/backends/qualcomm/passes/build_quant_io.py @@ -0,0 +1,52 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
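deduce_dtype above centralizes the range-based dtype pick that get_data_type previously did inline: fit quant_max - quant_min into the narrowest 8/16-bit integer type, unsigned when quant_min >= 0. A standalone restatement plus a usage check (the dict layout mirrors the patch; the 16-bit example values follow the uint16-range activation specs in quantizer/utils.py, but the snippet itself is not part of the patch):

import torch

def deduce_dtype_sketch(quant_infos):
    # Restatement of the rule added in builders/utils.py, for illustration only.
    quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
    unsigned = quant_infos["quant_min"] >= 0
    if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min + 1:
        return torch.uint8 if unsigned else torch.int8
    if quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
        return torch.uint16 if unsigned else torch.int16
    return quant_infos["dtype"]

# A 16-bit activation spec annotated as int32 with a [0, 65535] range resolves
# to uint16, which is why QNN_QUANT_TYPE_MAP carries torch.uint16 entries.
assert deduce_dtype_sketch(
    {"quant_min": 0, "quant_max": 65535, "dtype": torch.int32}
) == torch.uint16
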
+import torch + +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import TensorSpec + +from .utils import q_io_key + + +class BuildQuantIo(ExportPass): + + def __init__(self): + super(BuildQuantIo, self).__init__() + + def _make_spec(self, x): + if isinstance(x, torch.Tensor): + return TensorSpec.from_tensor(x) + elif isinstance(x, (int, bool, float)): + return x + else: + return None + + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # forcely update delegate node's meta['spec'] to get correct output + # tensor size in runtime + call_delegate = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.name == "executorch_call_delegate" + ] + assert len(call_delegate) == 1 + spec = [] + for n in graph_module.graph.nodes: + if q_io_key in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[q_io_key]) + if n.op == "call_function" and "getitem" in n.name: + fake_tensor = n.meta["val"] + if q_io_key in n.meta: + fake_tensor = fake_tensor.to(dtype=n.meta[q_io_key]) + spec.append(self._make_spec(fake_tensor)) + + call_delegate[0].meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/fuse_consecutive_transpose.py b/backends/qualcomm/passes/fuse_consecutive_transpose.py index aa5dc151a29..740b91dfaac 100644 --- a/backends/qualcomm/passes/fuse_consecutive_transpose.py +++ b/backends/qualcomm/passes/fuse_consecutive_transpose.py @@ -6,6 +6,7 @@ import torch +from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -27,25 +28,33 @@ def __init__(self): self.nodes = [] def _traverse(self, node): - if node.op == "call_function" and node.target in self.op_map: - self.nodes.append(node) - self.visited.add(node) - if len(node.users) == 1: - self._traverse(list(node.users)[0]) + if node in self.visited or not node.target in self.op_map: + return + + self.nodes.append(node) + self.visited.add(node) + next_users = [n for n in list(node.users) if n.target in self.op_map] + if not next_users: + return + + if len(next_users) == 1: + self._traverse(list(node.users)[0]) + else: + raise NotImplementedError( + f"Check the node {node}, wich encounter mutilple permute output case" + ) def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: graph = graph_module.graph for n in graph_module.graph.nodes: - if n in self.visited: - continue - if n.op == "call_function" and n.target in self.op_map: - self._traverse(n) - num_nodes = len(self.nodes) - if num_nodes > 1: + self._traverse(n) + if len(self.nodes) > 1: + permute_order = [] input_node, output_node = self.nodes[0].args[0], self.nodes[-1] input_shape = input_node.meta["val"].shape axis_order = torch.arange(len(input_shape)).tolist() for node in self.nodes: + permute_order.append(node.args[1]) axis_order = [axis_order[i] for i in node.args[1]] with graph.inserting_after(input_node): permute_op = exir_ops.edge.aten.permute_copy.default @@ -55,11 +64,19 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: users = output_node.users.copy() for user in users: user.replace_input_with(output_node, permute_node) - # copy metadata + # copy metadata permute_node.meta = 
output_node.meta - # clear current stack + # Without inserted_permute_tag, we might obtain wrong input shape + if any( + [ + pn.meta.get(LayoutTransform.inserted_permute_tag) + for pn in self.nodes + ] + ): + permute_node.meta[LayoutTransform.inserted_permute_tag] = True + # clear current stack self.nodes = [] def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/qualcomm/passes/insert_io_qdq.py b/backends/qualcomm/passes/insert_io_qdq.py index 5e6a03799cf..d88ca24fbba 100644 --- a/backends/qualcomm/passes/insert_io_qdq.py +++ b/backends/qualcomm/passes/insert_io_qdq.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import dq_ops, q_ops +from .utils import q_io_key, q_ops class InsertIOQDQ(ExportPass): @@ -107,41 +107,18 @@ def _insert_dequant_node( if user.op == "output": user.replace_input_with(node, inserted_node) - # When having requantization dq/q nodes at the input, - # such as the case: input1 -> dq_node1 -> q_node1 -> node1, - # we should fold the dq_node1 and connect input -> q_node1 -> node1. - def _fold_mix_quantization_dq_node(self, graph_module, input_node): - input_users = list(input_node.users.keys()) - for input_user in input_users: - if input_user.target not in dq_ops: - continue - dq_users = list(input_user.users.keys()) - for dq_user in dq_users: - dq_user.replace_input_with(input_user, input_node) - - # When having requantization dq/q nodes at the output, - # such as the case: node(int32) -> dq(int32) -> q(uint8) -> output(int32), - # we should fold the q node and connect node(int32) -> dq(int32) -> output(int32). - def _fold_mix_quantization_q_node(self, graph_module, node, users): - for user in users: - if user.op == "output": - output_node = user - break - - dq_node = node.args[0] - for out_node in output_node.meta["val"]: - if dq_node.meta["val"].dtype == out_node.dtype: - user.replace_input_with(node, dq_node) - def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: + # do nothing when a node is expected to output a quant tensor + if n.meta.get(q_io_key): + continue + # insert q after input or fold mix_quantization dq if applicable if ( n.op == "placeholder" and n.meta.get("quant_attrs") and not is_parameter(n, self.edge_program) ): - self._fold_mix_quantization_dq_node(graph_module, n) self._insert_quant_node( graph_module, n, n.meta["quant_attrs"]["encoding"] ) @@ -149,10 +126,6 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # insert dq before output or fold mix_quantization q if applicable users = list(n.users.keys()) if n.meta.get("quant_attrs") and any(user.op == "output" for user in users): - if n.target in q_ops: - self._fold_mix_quantization_q_node(graph_module, n, users) - # If q_node is fold, it will have no users, - # so it won't insert dequant node in following function. 
self._insert_dequant_node( graph_module, n, diff --git a/backends/qualcomm/passes/insert_requantize.py b/backends/qualcomm/passes/insert_requantize.py index d0169ebe357..c41bc16a5f5 100644 --- a/backends/qualcomm/passes/insert_requantize.py +++ b/backends/qualcomm/passes/insert_requantize.py @@ -6,11 +6,13 @@ import torch -from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from .utils import q_io_key -class InsertRequantize(InsertIOQDQ): + +class InsertRequantize(ExportPass): """ This pass inserts dq/q nodes for non-arithmetic operators which have different quantization specs in input and activation @@ -26,10 +28,9 @@ class InsertRequantize(InsertIOQDQ): def __init__( self, edge_program: torch.export.ExportedProgram, - insert_requantize: bool = False, ): - super().__init__(edge_program) - self.insert_requantize = insert_requantize + super(InsertRequantize, self).__init__() + self.edge_program = edge_program # TODO: Implement this function when we have an op with # multiple outputs that requires quant attributes. @@ -39,16 +40,21 @@ def _multi_output_annotation(self) -> None: def _single_output_annotation( self, gm: torch.fx.GraphModule, n: torch.fx.node ) -> None: - dq_attr = n.meta["quant_attrs"] - q_attr = n.meta["requantize"] - # insert dq with given quantization attribute in input node - dq = self._insert_quant_node( - gm, n, InsertIOQDQ.q_dq_map[q_attr["encoding"]], dq_attr - ) - dq.meta["quant_attrs"] = dq_attr - # insert q with given quantization attribute in current node - q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr) - q.meta["quant_attrs"] = q_attr + with gm.graph.inserting_after(n): + users = list(n.users.keys()) + inserted_n = gm.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (n,), + ) + + inserted_n.meta["val"] = n.meta["val"] + inserted_n.meta["quant_attrs"] = n.meta.pop("requantize") + if n.meta.get(q_io_key): + inserted_n.meta[q_io_key] = n.meta[q_io_key] + + for user in users: + user.replace_input_with(n, inserted_n) def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: @@ -59,3 +65,9 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: or n.target in self.multi_output_op_ignore_set else self._multi_output_annotation() ) + + def call(self, graph_module: torch.fx.GraphModule): + self._insert(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py index 8c86f1919ad..fbf1431f1a5 100644 --- a/backends/qualcomm/passes/layout_transform.py +++ b/backends/qualcomm/passes/layout_transform.py @@ -52,6 +52,9 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.pow.Tensor_Scalar, *q_ops, *dq_ops, _operator.getitem, @@ -109,7 +112,10 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool: return node.target in self.layout_sensitive_ops def is_layout_agnostic(self, node: torch.fx.Node) -> bool: - if node.target == exir_ops.edge.aten.mean.dim: + if node.target in [ + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + ]: # if dimemsion is not 
kept, we'll have no clue how to do layout transform if len(node.args) < 3 or not node.args[2]: return False diff --git a/backends/qualcomm/passes/utils.py b/backends/qualcomm/passes/utils.py index 49da1929e84..93e8a92a0d4 100755 --- a/backends/qualcomm/passes/utils.py +++ b/backends/qualcomm/passes/utils.py @@ -9,6 +9,8 @@ from executorch.exir.dialects._ops import ops as exir_ops +q_io_key = "q_tensor_io" + q_ops = { exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 1fac8056e80..e95aaa0aaea 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -52,7 +52,7 @@ def preprocess( InsertIOQDQ(edge_program), LayoutTransform(edge_program, insert_permute=True), # please enable this when apply convert_linear_to_conv2d - # FuseConsecutiveTranspose(), + FuseConsecutiveTranspose(), ] ) @@ -94,6 +94,8 @@ def preprocess( ) assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary." qnn_manager.Destroy() + del py_op_wrapper_list + del qnn_manager # For now, debug_handle_map is not used by QNN ExecuTorch return PreprocessResult( processed_bytes=bytes(qnn_context_binary), debug_handle_map={} diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 79e3e292057..69f9ad1d589 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -22,6 +22,7 @@ from .utils import ( get_16a4w_qnn_ptq_config, + get_16a8w_qnn_ptq_config, get_default_16bit_qnn_ptq_config, get_default_8bit_qnn_ptq_config, get_ptq_per_channel_weight_config, @@ -33,6 +34,7 @@ "QnnQuantizer", "QuantDtype", "get_16a4w_qnn_ptq_config", + "get_16a8w_qnn_ptq_config", "get_default_16bit_qnn_ptq_config", "get_default_8bit_qnn_ptq_config", ] diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index f9e0748b967..c9e21af767f 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -74,15 +74,18 @@ def _derive_bias_qparams_fn( ) -def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: +def get_default_8bit_qnn_ptq_config(act_symmetric: bool = False) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} act_quantization_spec = QuantizationSpec( dtype=torch.uint8, - quant_min=0, - quant_max=torch.iinfo(torch.uint8).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max + 1, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( @@ -113,7 +116,9 @@ def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: # 4 bits quantization only supports specific ops. 
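As a usage note for the 16-bit activation configs defined just below, they are consumed through QnnQuantizer before PT2E quantization; the snippet mirrors the call sequence used by llama.py and composite_llama.py later in this patch, trimmed for brevity:

from torch.ao.quantization.observer import MinMaxObserver

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config

quantizer = QnnQuantizer()
# route the supported ops to the 16-bit activation / 4-bit weight scheme
quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
quantizer.set_bit16_op_quant_config(
    get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver)
)
quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")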
-def get_16a4w_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig: +def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-20} act_quantization_spec = QuantizationSpec( dtype=torch.int32, @@ -150,7 +155,48 @@ def get_16a4w_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> Quanti return quantization_config -def get_default_16bit_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig: +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max + 1, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-20} act_quantization_spec = QuantizationSpec( dtype=torch.int32, @@ -215,7 +261,7 @@ def get_ptq_per_channel_weight_config( quant_min=torch.iinfo(act_dtype).min, quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( @@ -385,6 +431,11 @@ def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.sum.dim_IntList]) +def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator([torch.ops.aten.ceil.default]) def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -455,12 +506,16 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N @register_annotator([torch.ops.aten.permute.default]) def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) @register_annotator([torch.ops.aten.view.default]) def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + 
annotate_single_in_single_out(node, quantization_config) @register_annotator([torch.ops.aten.pixel_shuffle.default]) @@ -512,6 +567,11 @@ def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.sqrt.default]) +def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.gelu.default]) def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/setup.md b/backends/qualcomm/setup.md index 18ebf412fc0..b78b481e86e 100644 --- a/backends/qualcomm/setup.md +++ b/backends/qualcomm/setup.md @@ -93,7 +93,6 @@ mkdir build_android cd build_android # build executorch & qnn_executorch_backend cmake .. \ - -DBUCK2=buck2 \ -DCMAKE_INSTALL_PREFIX=$PWD \ -DEXECUTORCH_BUILD_QNN=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index a1eff816277..ba97240455f 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -122,8 +122,8 @@ def __init__( ) -> None: super().__init__() self.modules = [ - Conv2DSequential(), - Conv2DSequential(), + Conv2dSequential(), + Conv2dSequential(), Add(), Relu(), ] @@ -172,7 +172,7 @@ def forward(self, x, y): return CompositeReferenceModule(self.modules) -class Conv1DSequential(torch.nn.Module): +class Conv1dSequential(torch.nn.Module): def __init__(self): super().__init__() self.first = torch.nn.Conv1d( @@ -210,43 +210,6 @@ def forward(self, x): return x -class Conv2DSequential(torch.nn.Module): - def __init__(self): - super().__init__() - self.first = torch.nn.Conv2d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - self.second = torch.nn.Conv2d( - in_channels=3, - out_channels=2, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - - def forward(self, x): - return self.second(self.first(x)) - - -class Conv2DSingle(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - - def forward(self, x): - return self.conv(x) - - class Conv2dAvgPool2d(torch.nn.Module): def __init__(self): super().__init__() @@ -321,6 +284,58 @@ def forward(self, x): return self.pool(self.conv(x)) +class Conv2dSequential(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return self.second(self.first(x)) + + +class Conv2dSingle(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + +class Conv2dSumReduceDim(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return torch.sum(self.first(x), dim=(2, 3), keepdim=False) + + class Div(torch.nn.Module): def __init__(self): super().__init__() @@ -691,7 +706,7 @@ 
def __init__(self): super().__init__() def forward(self, x): - return x / torch.sqrt(torch.tensor([64])) + return x / torch.sqrt(torch.tensor([64.0])) class Squeeze(torch.nn.Module): @@ -748,6 +763,14 @@ def forward(self, x): return 10 - x +class SumIntList(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sum(x, dim=(2, 3), keepdim=True) + + class Tanh(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d539827fdb9..3874da9e981 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -95,12 +95,12 @@ def test_qnn_backend_clamp(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv1d(self): - module = Conv1DSequential() # noqa: F405 + module = Conv1dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv2d(self): - module = Conv2DSequential() # noqa: F405 + module = Conv2dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) self.lower_module_and_test_output(module, sample_input) @@ -183,11 +183,10 @@ def test_qnn_backend_element_wise_mul(self): self.lower_module_and_test_output(module, sample_input) index += 1 - @unittest.skip("not yet implemented") def test_qnn_backend_element_wise_sqrt(self): modules = [Sqrt(), SqrtConstant()] # noqa: F405 - sample_input = (torch.randn([3, 1]),) for i, module in enumerate(modules): + sample_input = (torch.rand([3, 1]),) with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) @@ -357,6 +356,11 @@ def test_qnn_backend_squeeze(self): sample_input = (torch.randn([1, 3, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sum_int_list(self): + module = SumIntList() # noqa: F405 + sample_input = (torch.randn([1, 4, 8, 8]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -421,6 +425,11 @@ def test_qnn_backend_conv2d_max_pool2d(self): sample_input = (torch.rand(1, 2, 14, 14),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_sum_reduce_dim(self): + module = Conv2dSumReduceDim() # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_residual_block(self): module = ResidualBlockModule() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) @@ -494,7 +503,7 @@ def setUp(self): ) def test_qnn_backend_16a4w_conv2d(self): - module = Conv2DSingle() # noqa: F405 + module = Conv2dSingle() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) module = self.get_qdq_module( module, sample_input, quant_dtype=QuantDtype.use_16a4w @@ -575,13 +584,13 @@ def test_qnn_backend_clamp(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv1d(self): - module = Conv1DSequential() # noqa: F405 + module = Conv1dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv2d(self): - module = Conv2DSequential() # noqa: F405 + module = Conv2dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) module = self.get_qdq_module(module, sample_input) 
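The quantized variants of these tests wrap the module with get_qdq_module before lowering; conceptually that is the same prepare, calibrate, convert round trip spelled out in composite_llama.py below. A rough sketch under that assumption (the helper name to_qdq and the single calibration pass are illustrative, not the test harness's actual code):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer


def to_qdq(module: torch.nn.Module, sample_input: tuple) -> torch.fx.GraphModule:
    quantizer = QnnQuantizer()
    # capture an FX graph suitable for PT2E quantization
    captured = torch._export.capture_pre_autograd_graph(module, sample_input)
    # attach observers according to the registered annotators
    annotated = prepare_pt2e(captured, quantizer)
    # run calibration data through the observed graph
    annotated(*sample_input)
    # fold observer statistics into quantize/dequantize nodes
    return convert_pt2e(annotated)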
self.lower_module_and_test_output(module, sample_input) @@ -669,11 +678,10 @@ def test_qnn_backend_element_wise_mul(self): self.lower_module_and_test_output(module, sample_input) index += 1 - @unittest.skip("not yet implemented") def test_qnn_backend_element_wise_sqrt(self): modules = [Sqrt(), SqrtConstant()] # noqa: F405 - sample_input = (torch.randn([3, 1]),) for i, module in enumerate(modules): + sample_input = (torch.rand([3, 1]),) with self.subTest(i=i): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) @@ -873,6 +881,12 @@ def test_qnn_backend_stack(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sum_int_list(self): + module = SumIntList() # noqa: F405 + sample_input = (torch.randn([1, 4, 8, 8]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -946,6 +960,12 @@ def test_qnn_backend_conv2d_max_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_sum_reduce_dim(self): + module = Conv2dSumReduceDim() # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): instances = [ {"module": DeepLabV3ResNet101Model(), "annotation": ()}, @@ -1095,6 +1115,7 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) + @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -1227,6 +1248,7 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) + @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=False) @@ -1323,6 +1345,40 @@ def test_fbnet(self): self.assertGreaterEqual(msg["top_1"], 60) self.assertGreaterEqual(msg["top_5"], 90) + def test_ssd300_vgg16(self): + if not self.required_envs([self.pretrained_weight, self.oss_repo]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--oss_repo", + self.oss_repo, + "--pretrained_weight", + self.pretrained_weight, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + self.assertGreaterEqual(msg["mAP"], 0.70) + class TestExampleScript(TestQNN): def required_envs(self, conditions=None) -> bool: @@ -1771,6 +1827,11 @@ def setup_environment(): help="Emit log only when error happened", action="store_true", ) + parser.add_argument( + "--oss_repo", + help="Path to open source software model repository", + type=str, + ) args, 
ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -1785,6 +1846,7 @@ def setup_environment(): TestQNN.online_prepare = args.online_prepare TestQNN.enable_profile = args.enable_profile TestQNN.error_only = args.error_only + TestQNN.oss_repo = args.oss_repo TestQNN.shared_buffer = args.shared_buffer return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 4a38fd3a24a..bfeb1bb649d 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -69,13 +69,11 @@ def __init__(self, weight, bias=None): self.conv = torch.nn.Conv2d( in_channels=weight.shape[0], out_channels=weight.shape[1], - kernel_size=(1, 1), + kernel_size=1, padding=0, bias=use_bias, ) - self.conv.weight = torch.nn.Parameter( - weight.reshape(*weight.shape, 1, 1) - ) + self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1)) if use_bias: self.conv.bias = torch.nn.Parameter(bias) @@ -89,10 +87,13 @@ def forward(self, x): return res def replace_linear(module: torch.nn.Module): - for attr_str in dir(module): + attr_strs = dir(module) + if type(module) == torch.nn.ModuleList: + attr_strs += [str(i) for i in range(len(module))] + + for attr_str in attr_strs: target_attr = getattr(module, attr_str) if type(target_attr) == torch.nn.Linear: - print('replaced: ', attr_str) setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias)) for _, sub_module in module.named_children(): @@ -223,6 +224,7 @@ def generate_htp_compiler_spec( # TODO: enable voting mechanism in runtime and make this as an option htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst htp_options.use_multi_contexts = use_multi_contexts + htp_options.max_sf_buf_size = 73859072 htp_options.use_dlbc = use_dlbc return QnnExecuTorchBackendOptions( backend_type=QnnExecuTorchBackendType.kHtpBackend, diff --git a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp index 9733baf5c12..ab17a551d49 100644 --- a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp @@ -25,9 +25,11 @@ #include #include +#include +#include DEFINE_string( - model_path, + model_paths, "qnn_llama2.pte", "Model serialized in flatbuffer format."); @@ -57,22 +59,30 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - const char* model_path = FLAGS_model_path.c_str(); + std::vector model_path_list; + std::istringstream f(FLAGS_model_paths); + std::string s; + while (getline(f, s, ',')) { + model_path_list.push_back(s); + } + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); const char* prompt = FLAGS_prompt.c_str(); double temperature = FLAGS_temperature; int32_t seq_len = FLAGS_seq_len; // create llama runner - Runner runner(model_path, tokenizer_path, temperature); + Runner runner(model_path_list, tokenizer_path, temperature); ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); // MethodMeta describes the memory requirements of the method. 
- Result method_meta = runner.method_meta(); - ET_CHECK_MSG( + std::vector> method_metas = runner.get_methods_meta(); + for(auto& method_meta: method_metas){ + ET_CHECK_MSG( method_meta.ok(), "Failed to get method_meta 0x%x", (unsigned int)method_meta.error()); + } // Fill in data for input std::ifstream input_list(FLAGS_input_list_path); @@ -93,47 +103,106 @@ int main(int argc, char** argv) { }; std::string file_path; - size_t inference_index = 0, num_inputs = method_meta->num_inputs(); - std::vector> inputs(num_inputs); + size_t inference_index = 0; + std::vector> freqs_inputs(2); + std::vector>> inputs(method_metas.size()-2); + + for (int i = 1; i < method_metas.size()-1; i++){ + size_t num_inputs = method_metas[i]->num_inputs(); + inputs[i-1].resize(num_inputs); + } + while (std::getline(input_list, file_path)) { auto input_files = split(file_path, " "); if (input_files.size() == 0) { break; } - // inputs: [tokens, pos_ids, kv_mask, *k_cache, *v_cache] + // inputs: [tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, *k_cache, *v_cache] // tokens are determined by command line arguments - // pos_ids are infered inside runner - std::vector managed_inputs; - for (int input_index = 2; input_index < num_inputs; ++input_index) { - Result tensor_meta = - method_meta->input_tensor_meta(input_index); + // pos_ids, atten_mask are infered inside runner + for (int input_index = 2; input_index < 4; ++input_index) { std::ifstream fin(input_files[input_index], std::ios::binary); fin.seekg(0, fin.end); size_t file_size = fin.tellg(); - ET_CHECK_MSG( - file_size == tensor_meta->nbytes(), - "Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu", - input_index, - file_size, - tensor_meta->nbytes()); - - inputs[input_index].resize(tensor_meta->nbytes()); + freqs_inputs[input_index-2].resize(file_size / sizeof(float)); + fin.seekg(0, fin.beg); + fin.read(reinterpret_cast(freqs_inputs[input_index-2].data()), file_size); fin.close(); + } - auto tensor_shape = tensor_meta->sizes(); - std::vector sizes( - tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); - - managed_inputs.emplace_back(ManagedTensor( - inputs[input_index].data(), 128, sizes, tensor_meta->scalar_type())); + std::vector> managed_kv_inputs(method_metas.size()-2); + for (int i = 1; i < method_metas.size()-1; ++i){ + size_t num_inputs = method_metas[i]->num_inputs(); + const int k_caches_end = (num_inputs - 4) / 2; + + // TODO: need to handle batch size != 1 + // k caches init + for (int input_index = 4; input_index < k_caches_end; ++input_index) { + Result tensor_meta = + method_metas[i]->input_tensor_meta(input_index); + int file_index = (i-1) * (num_inputs - 4) + input_index + 1; + std::ifstream fin(input_files[file_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. 
file bytes: %zu, tensor bytes: %zu", + file_index, + file_size, + tensor_meta->nbytes()); + + // to simplify kv_cache update logic, we use (bsz, head_dim+2, seq) + // for fast pointer shifting + // head_dim+1 is the buffer of last word + // head_dim+2 is for output + inputs[i-1][input_index].resize(tensor_meta->nbytes() + 2*(tensor_meta->nbytes()/tensor_meta->sizes()[1])); + fin.close(); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_kv_inputs[i-1].emplace_back(ManagedTensor( + inputs[i-1][input_index].data(), 128, sizes, tensor_meta->scalar_type())); + } + + // v caches init + for (int input_index = k_caches_end; input_index < num_inputs; ++input_index) { + Result tensor_meta = + method_metas[i]->input_tensor_meta(input_index); + int file_index = (i-1) * (num_inputs - 4) + input_index + 1; + std::ifstream fin(input_files[file_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu", + file_index, + file_size, + tensor_meta->nbytes()); + + // to simplify v_cache update logic, we use (bsz, 2*max_seq_len, head_dim) + // for fast pointer shifting + inputs[i-1][input_index].resize(2*tensor_meta->nbytes()); + fin.close(); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_kv_inputs[i-1].emplace_back(ManagedTensor( + inputs[i-1][input_index].data(), 128, sizes, tensor_meta->scalar_type())); + } } // generate tokens std::string inference_output; runner.generate( - prompt, seq_len, managed_inputs, [&](const std::string& piece) { + prompt, seq_len, managed_kv_inputs, freqs_inputs, [&](const std::string& piece) { inference_output += piece; }); diff --git a/examples/qualcomm/llama2/composite_llama.py b/examples/qualcomm/llama2/composite_llama.py new file mode 100644 index 00000000000..da4d19f7480 --- /dev/null +++ b/examples/qualcomm/llama2/composite_llama.py @@ -0,0 +1,873 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
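The cache handling in the runner above and the calibrate() loop below share the same sliding update: drop the oldest cache position and append the newly produced slice. A small sketch of that update; head_dim = 64 and a 128-token window are placeholder numbers, and the shapes follow the patch's convention of ahead-of-time transposed k caches:

import torch


def shift_kv(k_cache, v_cache, new_k, new_v):
    # k caches are (batch, head_dim, seq_len - 1) after the ahead-of-time
    # transpose, so the new column is appended on the last axis
    k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
    # v caches stay (batch, seq_len - 1, head_dim); the new row is appended
    v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)
    return k_cache, v_cache


k = torch.zeros(1, 64, 127)
v = torch.zeros(1, 127, 64)
k, v = shift_kv(k, v, torch.ones(1, 64, 1), torch.ones(1, 1, 64))
assert k.shape == (1, 64, 127) and v.shape == (1, 127, 64)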
+ +import codecs +import gc +import getpass +import json +import os +import shutil +import stat +import sys +from pathlib import Path + +sys.setrecursionlimit(4096) + +import time +from typing import List, Tuple + +import torch + +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner +from executorch.backends.qualcomm.passes.build_quant_io import BuildQuantIo +from executorch.backends.qualcomm.passes.utils import q_io_key + +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + capture_program, + convert_linear_to_conv2d, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, +) +from executorch.examples.models.llama2.builder import DType +from executorch.examples.models.llama2.llama_transformer import precompute_freqs_cis +from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs +from executorch.examples.qualcomm.scripts.utils import ( + make_output_dir, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass + +from sentencepiece import SentencePieceProcessor +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + + +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: + """ + This function is specific for matmul op 16a8w. 
+ """ + from typing import Sequence + + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, + ) + from torch.fx import Node + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_cat(node: Node, quantization_config: QuantizationConfig): + input_nodes = node.args[0] + + assert isinstance(input_nodes, Sequence) + + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(node, Node) + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + + def annotate_single_in_single_out( + node: Node, quantization_config: QuantizationConfig + ) -> None: + + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_matmul_input1(node: Node): + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + while isinstance(node, Node) and node.op == "call_function": + if node.target in [ + torch.ops.aten.permute.default, + torch.ops.aten.transpose.int, + ]: + annotate_single_in_single_out(node, quantization_config_8a8w) + node = node.args[0] + elif node.target == torch.ops.aten.cat.default: + annotate_cat(node, quantization_config_8a8w) + node = node.args[0][0] + else: + node = node.args[0] + + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + annotate_matmul(node, quantization_config_16a8w) + annotate_matmul_input1(node.args[1]) + + +def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_ptq_per_channel_weight_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import QuantizationAnnotation + from torch.fx import Node + + def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + 
input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + weight = node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = quantization_config.weight + + if len(node.args) > 2: + bias = node.args[2] + if isinstance(bias, Node): + if callable(quantization_config.bias): + input_qspec_map[bias] = quantization_config.bias(node) + else: + input_qspec_map[bias] = quantization_config.bias + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + quantization_config_16a8w_per_channel = get_ptq_per_channel_weight_config( + torch.uint16, weight_dtype=torch.int8 + ) + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: + if "nn_module_stack" in node.meta: + module_values_list = list(node.meta["nn_module_stack"].values()) + full_qualified_name = module_values_list[0][0] + if full_qualified_name == "L['self'].llama.output": + annotate_conv2d( + node, quantization_config=quantization_config_16a8w_per_channel + ) + + +def calibrate( + example_inputs, n_heads, layers_per_ctx, modules: List[torch.fx.GraphModule] +): + sp_model = SentencePieceProcessor(model_file="tokenizer.model") + _, _, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + user_prompts = ["Once"] + for prompt in user_prompts: + token_list += sp_model.encode(prompt) + + def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0 + probs_sort /= probs_sort.sum(dim=-1, keepdim=True) + next_token = torch.multinomial(probs_sort, num_samples=1) + return probs_indices.gather(dim=-1, index=next_token) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + hidden_states = modules[0](torch.full((1, 1), token_list[pos])) + input_pos = torch.full((1, 1), pos) + k_caches_o_list = [] + v_caches_o_list = [] + for i, decode_module in enumerate(modules[1:-1]): + offset = i * layers_per_ctx * n_heads + k_caches_i = k_caches[offset : offset + layers_per_ctx * n_heads] + v_caches_i = v_caches[offset : offset + layers_per_ctx * n_heads] + hidden_states, k_caches_o, v_caches_o = decode_module( + hidden_states, + freqs_cos[input_pos][0], + freqs_sin[input_pos][0], + atten_mask, + k_caches_i, + v_caches_i, + ) + k_caches_o_list.extend(k_caches_o) + v_caches_o_list.extend(v_caches_o) + + logits = modules[-1](hidden_states) + # k_caches have been transposed ahead, the shpae is [batch, head_dim, seq-1] + k_caches = [ + torch.cat([k_cache[:, :, 1:], k_caches_o_list[i]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], v_caches_o_list[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + +class CompositeLlama: + def __init__(self, division, llama_model) -> None: + super().__init__() + self.division = division + self.layers_per_ctx = llama_model.n_layers // division + self.llama_model = llama_model + 
self.quant_dtype = None + self.split_modules, self.split_inputs = [], [] + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + + def split_llama(self): + def get_block_module(llama, indexes): + class LlamaBlock(torch.nn.Module): + def __init__(self, llama, indexes) -> None: + super().__init__() + self.llama = llama + self.indexes = indexes + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache, output_v_cache = [], [] + for i, ind in enumerate(self.indexes): + offset = i * self.llama.n_heads + k_in = k_caches[offset : offset + self.llama.n_heads] + v_in = v_caches[offset : offset + self.llama.n_heads] + hidden_states, k, v = self.llama.layers[ind]( + x=hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_in, + v_caches=v_in, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + return hidden_states, output_k_cache, output_v_cache + + return LlamaBlock(llama, indexes) + + def get_affine_module(llama): + class LlamaAffine(torch.nn.Module): + def __init__(self, llama) -> None: + super().__init__() + self.llama = llama + + def forward(self, hidden_states): + hidden_states = self.llama.norm(hidden_states) + logits = self.llama.output(hidden_states) + return logits + + return LlamaAffine(llama) + + tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches = ( + self.get_example_inputs() + ) + + with torch.no_grad(): + # embedding + self.split_modules.append(self.llama_model.tok_embeddings) + self.split_inputs.append((tokens,)) + + # attentions + for i in range(self.division): + llama_block = get_block_module( + self.llama_model, + [*range(self.layers_per_ctx * i, self.layers_per_ctx * (i + 1))], + ) + offset = i * self.layers_per_ctx * self.llama_model.n_heads + k_caches_in = k_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + v_caches_in = v_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + self.split_modules.append(llama_block) + self.split_inputs.append( + ( + self.llama_model.tok_embeddings(tokens), + freqs_cos[pos_ids][0], + freqs_sin[pos_ids][0], + atten_mask, + k_caches_in, + v_caches_in, + ) + ) + + # affine layer + affine_block = get_affine_module(self.llama_model) + self.split_modules.append(affine_block) + self.split_inputs.append((self.llama_model.tok_embeddings(tokens),)) + + def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type=torch.float32): + if not self.has_quant_io: + return + + # shape of k caches and v caches + input_cache_shape = { + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + } + for n in gm.graph.nodes: + if ( + n.op == "placeholder" + and len(users := list(n.users)) == 1 + and users[0].meta["val"].size()[-2:] in input_cache_shape + ): + n.meta[q_io_key] = kv_type + elif n.op == "output": + for a in n.args[0]: + if ( + a.meta["val"].flatten().size()[0] + == self.llama_meta["get_head_dim"] + ): + a.meta[q_io_key] = kv_type + + def quantize(self, quant_dtype, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = QnnQuantizer() + quantizer.set_per_channel_linear_quant(True) + quantizer.set_per_channel_conv_quant(True) + + if quant_dtype == QuantDtype.use_8a8w: + pass # 
default setting + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + split_fx_graph_modules = [] + + with torch.no_grad(): + for nn_module, capture_inputs in zip(self.split_modules, self.split_inputs): + fx_graph_module = torch._export.capture_pre_autograd_graph( + nn_module, capture_inputs + ) + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + split_fx_graph_modules.append(fx_graph_module) + print("Quantizing the model...") + calibrate( + self.get_example_inputs(), + self.llama_model.n_heads, + self.layers_per_ctx, + split_fx_graph_modules, + ) + + self.split_modules = [ + convert_pt2e(fx_graph_module) for fx_graph_module in split_fx_graph_modules + ] + del self.llama_model + + def lowering_modules(self, work_space, kv_type=torch.float32): + + executorch_config = ExecutorchBackendConfig( + passes=[ + BuildQuantIo(), + ], + extract_constant_segment=False, + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. + memory_planning_pass=MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + pte_filename_list = [] + index = len(self.split_modules) + with torch.no_grad(): + while index > 0: + # backend option + backend_options = generate_htp_compiler_spec( + use_fp16=True if self.quant_dtype is None else False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=QcomChipset.SM8650, + backend_options=backend_options, + # saver=True if index==5 else False + ) + partitioner = QnnPartitioner(compiler_specs) + pte_filename = f"llama2_qnn_{index-1}" + edge_prog = capture_program( + self.split_modules[index - 1], self.split_inputs[index - 1] + ) + self._tag_kv_ios( + edge_prog.exported_program.graph_module, kv_type=kv_type + ) + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=self.llama_meta, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + del edge_prog + del edge_prog_mgr + del exec_prog_mgr + self.split_modules.pop() + self.split_inputs.pop() + gc.collect(generation=2) + pte_filename_list.insert(0, f"{work_space}/{pte_filename}.pte") + index -= 1 + return pte_filename_list + + def get_example_inputs(self): + tokens, pos_ids, atten_mask, k_caches, v_caches = ( + self.llama_model.get_example_inputs() + ) + freqs_cos, freqs_sin = precompute_freqs_cis( + self.llama_model.dim // self.llama_model.n_heads, + self.llama_model.max_seq_len, + self.llama_model.rope_freq_base, + ) + return (tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches) + + def get_export_inputs(self): + tokens, pos_ids, atten_mask, k_caches, v_caches = ( + self.llama_model.get_export_inputs() + ) + freqs_cos, freqs_sin = precompute_freqs_cis( 
+ self.llama_model.dim // self.llama_model.n_heads, + self.llama_model.max_seq_len, + self.llama_model.rope_freq_base, + ) + export_inputs = [tokens, pos_ids, freqs_cos, freqs_sin, atten_mask] + for i in range(self.division): + offset = i * self.layers_per_ctx * self.llama_model.n_heads + k_caches_in = k_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + v_caches_in = v_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + export_inputs.append(k_caches_in) + export_inputs.append(v_caches_in) + + return tuple(export_inputs) + + +def create_device_inputs(example_inputs, kv_input_numel, kv_type=torch.float32): + # TODO: support batch inputs if necessary + input_list = "" + inputs, flat_inputs = [], [] + for input in example_inputs: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + + for i, data in enumerate(flat_inputs): + input_list += f"input_0_{i}.raw " + if data.flatten().shape[0] == kv_input_numel: + data = data.to(dtype=kv_type) + inputs.append(data) + + input_list += "\n" + return tuple(inputs), input_list + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./llama2_qnn", + default="./llama2_qnn", + type=str, + ) + + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.", + default="16a4w", + ) + + parser.add_argument( + "--checkpoint", + help="Pass llama2 checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--params", + help="Pass llama2 params json file.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="Pass llama2 tokenizer binary.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama2 tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama2.", + required=True, + type=str, + ) + + parser.add_argument( + "--seq_len", + help="Ouput sequence length for llama2.", + default=128, + type=int, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama2.", + default=0.8, + type=float, + ) + + parser.add_argument( + "-d", + "--dtype-override", + default="fp32", + type=str, + choices=["fp32", "fp16"], + help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32", + ) + + parser.add_argument( + "--pre_gen_pte", + help="Pre-generated llama2.", + type=str, + ) + + args = parser.parse_args() + division = 4 + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + with open(args.params) as f: + config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + config.max_batch_size = 1 + config.max_seq_len = 1024 + device = "cpu" + state_dict = torch.load(args.checkpoint, map_location=device, mmap=True) + end_load_ts = time.time() + print("torch.load checkpoint", end_load_ts - start_ts) + llama_instance = None + with torch.device("meta"): + llama_instance = LlamaModel(config, output_new_cache_only=True) + if "model" in state_dict: + state_dict = state_dict["model"] + llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) + end_load_state_dict_ts = time.time() + print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts) + + for l in llama_instance.layers: + if getattr(l.attention, "prepare_sha", None): + l.attention.prepare_sha() + kv_type = torch.uint8 + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + ) + + if args.use_fp16: + quant_dtype = None + else: + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + if args.dtype_override is not None: + dtype_override = DType[args.dtype_override] + llama_instance = llama_instance.to(dtype_override.to_torch_dtype()) + + llama_instance = convert_linear_to_conv2d(llama_instance) + + composite_llama = CompositeLlama(division, llama_instance.eval()) + kv_input_numel = ( + composite_llama.llama_meta["get_max_seq_len"] - 1 + ) * composite_llama.llama_meta["get_head_dim"] + start_split_ts = time.time() + inputs, input_list = create_device_inputs( + composite_llama.get_export_inputs(), kv_input_numel, kv_type + ) + pte_filename_list = [] + if args.pre_gen_pte is None: + composite_llama.split_llama() + end_split_ts = time.time() + print("composite_llama.split_llama()", end_split_ts - start_split_ts) + + if quant_dtype is not None: + composite_llama.quantize( + quant_dtype, + custom_annotations=( + annotate_matmul_16a8w, + annotate_linear_16a8w_in_affine_layer, + ), + ) + end_quantize_ts = time.time() + print( + "composite_llama.quantize(quant_dtype)", end_quantize_ts - end_split_ts + ) + del llama_instance + pte_filename_list = composite_llama.lowering_modules( + args.artifact, kv_type=kv_type + ) + assert len(pte_filename_list) != 0, "Failed to save pte file." + end_lowering_ts = time.time() + print("Complete Compile", end_lowering_ts - end_quantize_ts) + else: + for i in range(division + 2): + pte_filename = f"llama2_qnn_{i}" + pte_filename_list.append(f"{args.pre_gen_pte}/{pte_filename}.pte") + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/composite_llama" + pte_filenames = [Path(pte_filename).name for pte_filename in pte_filename_list] + + runner_args = " ".join( + [ + f"--model_paths {','.join(pte_filenames)}", + "--output_folder_path outputs", + "--input_list_path input_list.txt", + f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", + f"--prompt {args.prompt}", + f"--seq_len {args.seq_len}", + f"--temperature {args.temperature}", + ] + ) + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. 
&&", + f"./qnn_llama_runner {runner_args}", + ] + ) + + if not args.compile_only: + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + artifact_path=f"{args.build_folder}", + pte_path=pte_filename_list, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/qnn_llama_runner", + ) + adb.push(inputs=[inputs], input_list=input_list, files=[args.tokenizer_bin]) + adb.execute(custom_runner_cmd=runner_cmd) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + with codecs.open( + os.path.join(output_data_folder, f), + "r", + encoding="utf-8", + errors="replace", + ) as fdata: + outputs.append(fdata.read()) + + adb.pull(output_path=args.artifact, callback=post_process) + + for idx, output in enumerate(outputs): + print(f"Results[{idx}]:\n{output}") + + else: + compile_only_dir = os.path.join(args.artifact, args.artifact) + to_device_dir = os.path.join(compile_only_dir, "to_device") + os.makedirs(to_device_dir, exist_ok=True) + # input_list + input_list_file = os.path.join(to_device_dir, "input_list.txt") + with open(input_list_file, "w") as f: + f.write(input_list) + + # write inputs + for idx, data in enumerate([inputs]): + flat_inputs = [] + for d in data: + if isinstance(d, list): + for dd in d: + flat_inputs.append(dd) + else: + flat_inputs.append(d) + for i, d in enumerate(flat_inputs): + filename = os.path.join(to_device_dir, f"input_{idx}_{i}.raw") + d.detach().numpy().tofile(filename) + + # binaries + arch_table = { + "SM8650": "75", + "SM8550": "73", + "SM8475": "69", + "SM8450": "69", + } + dsp_arch = arch_table[args.model] + qnn_sdk_root = os.getenv("QNN_SDK_ROOT") + + on_device_files = [ + os.path.join(qnn_sdk_root, "lib", "aarch64-android", "libQnnHtp.so"), + os.path.join( + qnn_sdk_root, + "lib", + f"hexagon-v{dsp_arch}", + "unsigned", + f"libQnnHtpV{dsp_arch}Skel.so", + ), + os.path.join( + qnn_sdk_root, "lib", "aarch64-android", f"libQnnHtpV{dsp_arch}Stub.so" + ), + os.path.join(qnn_sdk_root, "lib", "aarch64-android", "libQnnSystem.so"), + os.path.join(args.build_folder, "examples", "qualcomm", "qnn_llama_runner"), + os.path.join( + args.build_folder, + "backends", + "qualcomm", + "libqnn_executorch_backend.so", + ), + ] + pte_filename_list + + for on_device_file in on_device_files: + shutil.copy2(on_device_file, to_device_dir) + + # tokenizer + shutil.copy2(args.tokenizer_bin, to_device_dir) + + run_sh_lines = [ + "set -e", + 'SOURCEDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"', + f'adb_cmd="adb -s {args.device} -H {args.host}"', + f'${{adb_cmd}} shell "rm -rf {workspace} && mkdir -p {workspace}/outputs"', + f"${{adb_cmd}} push ${{SOURCEDIR}}/to_device/* {workspace}", + f'${{adb_cmd}} shell "{runner_cmd}"', + "echo", + "echo ----- output_0_0.raw -----", + "echo", + f'${{adb_cmd}} shell "cat {workspace}/outputs/output_0_0.raw"', + "", + ] + + run_sh_file = os.path.join(compile_only_dir, "run.sh") + with open(run_sh_file, "w") as fp: + fp.write("\n".join(run_sh_lines)) + + os.chmod(run_sh_file, stat.S_IRWXU | stat.S_IRWXG) + + print("Zipping files.....") + shutil.make_archive( + compile_only_dir, + "zip", + root_dir=args.artifact, + base_dir=os.path.relpath(compile_only_dir, args.artifact), + ) + + print(f"Compile only mode, necessary files are written to 
{compile_only_dir}") + print(f"And it's zipped as {compile_only_dir}.zip") diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py index 49170184742..8c31d372f75 100644 --- a/examples/qualcomm/llama2/llama.py +++ b/examples/qualcomm/llama2/llama.py @@ -12,7 +12,6 @@ from functools import partial import torch -from torch.ao.quantization.observer import MinMaxObserver from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d @@ -25,6 +24,7 @@ ) from sentencepiece import SentencePieceProcessor +from torch.ao.quantization.observer import MinMaxObserver def create_device_inputs(example_inputs): @@ -48,7 +48,7 @@ def create_device_inputs(example_inputs): def calibrate(example_inputs, module: torch.fx.GraphModule): sp_model = SentencePieceProcessor(model_file="tokenizer.model") - _, _, kv_mask, k_caches, v_caches = example_inputs + _, _, atten_mask, k_caches, v_caches = example_inputs # TODO: change criteria & support batch inputs if necessary pos = torch.tensor(0, dtype=torch.int32) @@ -68,14 +68,24 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: with torch.no_grad(): while token_list[-1] != sp_model.eos_id() and pos < 128: - logits, kv_mask, k_caches, v_caches = module( + logits, new_k_caches, new_v_caches = module( torch.full((1, 1), token_list[pos]), torch.full((1, 1), pos), - kv_mask, + atten_mask, *k_caches, *v_caches, ) + k_caches = [ + torch.cat([k_cache[:, 1:, :], new_k_caches[i]], dim=1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + pos += 1 + atten_mask[0][-pos - 1] = 0 if pos >= len(token_list): probs = torch.softmax(logits[:, -1] / 0.8, dim=-1) token_list.append(sample_top_p(probs, 0.9).item()) @@ -174,8 +184,12 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: config.max_batch_size = 1 state_dict = torch.load(args.checkpoint) - instance = LlamaModel(config) - instance.load_state_dict(state_dict["model"]) + if "model" in state_dict: + state_dict = state_dict["model"] + with torch.device("meta"): + instance = LlamaModel(config) + instance.load_state_dict(state_dict, strict=False, assign=True) + inputs, input_list = create_device_inputs(instance.get_export_inputs()) pte_filename = "llama2_qnn" @@ -193,6 +207,11 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: else: assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + # prepare sha if the function is provided + for l in instance.layers: + if getattr(l.attention, "prepare_sha", None): + l.attention.prepare_sha() + if args.pre_gen_pte is None: build_executorch_binary( # try this if you want: convert_linear_to_conv2d(instance.eval()), @@ -207,7 +226,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: shared_buffer=args.shared_buffer, metadata=instance.get_metadata(), direct_io=True, - act_observer=MinMaxObserver + act_observer=MinMaxObserver, ) if args.compile_only: diff --git a/examples/qualcomm/llama2/model/static_llama.py b/examples/qualcomm/llama2/model/static_llama.py index c4cc25607a1..3b98700b220 100644 --- a/examples/qualcomm/llama2/model/static_llama.py +++ b/examples/qualcomm/llama2/model/static_llama.py @@ -10,7 +10,6 @@ import torch.nn as nn from executorch.examples.models.llama2.llama_transformer import ( - apply_rotary_emb, FeedForward, ModelArgs, precompute_freqs_cis, @@ -18,8 
+17,20 @@ ) +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + class LlamaAttention(nn.Module): - def __init__(self, config: ModelArgs, split_kv_cache=False): + def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() self.dim = config.dim self.n_heads = config.n_heads @@ -27,7 +38,7 @@ def __init__(self, config: ModelArgs, split_kv_cache=False): self.n_kv_heads = config.n_kv_heads self.num_key_value_groups = config.n_heads // self.n_kv_heads self.max_seq_len = config.max_seq_len - self.split_kv_cache = split_kv_cache + self.output_new_cache_only = output_new_cache_only self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) @@ -38,17 +49,91 @@ def __init__(self, config: ModelArgs, split_kv_cache=False): scale = float(self.head_dim) ** -0.5 scale_tensor = torch.tensor( - [scale], dtype=torch.float32, requires_grad=False + [scale], dtype=torch.float32, requires_grad=False, device="cpu" ).view(1, 1, 1) self.register_buffer("scale_tensor", scale_tensor, False) + def prepare_sha(self): + self.wq_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wv_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + + self.forward_mha = self.forward + self.forward = self.forward_sha + + for i in range(self.n_heads): + self.wq_sha[i].weight.data.copy_( + self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wk_sha[i].weight.data.copy_( + self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv_sha[i].weight.data.copy_( + self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + + def forward_sha( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seqlen, _ = hidden_states.shape + + q = [wq_sha(hidden_states) for wq_sha in self.wq_sha] + k = [wk_sha(hidden_states) for wk_sha in self.wk_sha] + v = [wv_sha(hidden_states) for wv_sha in self.wv_sha] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + + output_kh, output_vh, output_y = [], [], [] + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat([k_caches[i], k[i]], dim=-1) + vh = torch.cat([v_caches[i], v[i]], dim=1) + + attn = q[i] @ kh + attn = attn * self.scale_tensor + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + if self.output_new_cache_only: + output_kh.append(k[i]) + output_vh.append(v[i]) + else: + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + return y, output_kh, output_vh + def forward( self, hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, atten_mask: torch.Tensor, - kv_mask: torch.Tensor, k_caches: List[torch.Tensor], 
v_caches: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -59,45 +144,40 @@ def forward( k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 1) - if self.split_kv_cache: - output_kh, output_vh, output_y = [], [], [] - for i, _ in enumerate(k_caches): - kh = k_caches[i] + k[:, :, i, :] * kv_mask - vh = v_caches[i] + v[:, :, i, :] * kv_mask + output_kh, output_vh, output_y = [], [], [] - attn = q[:, :, i, :] @ kh.permute(0, 2, 1) - attn = attn * self.scale_tensor + atten_mask - attn = self.attn_softmax(attn) - y = attn @ vh + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat( + [k_caches[i], k[:, :, :, i]], dim=-1 + ) # TODO verify the correctness + vh = torch.cat([v_caches[i], v[:, :, i, :]], dim=1) - output_kh.append(kh) - output_vh.append(vh) - output_y.append(y) - - y = torch.concat(output_y, dim=-1) - y = self.wo(y) - return y, output_kh, output_vh - else: - k = k_caches + k * kv_mask - v = v_caches + v * kv_mask - - attn = q.transpose(1, 2) @ k.permute(0, 2, 3, 1) + attn = q[:, :, i, :] @ kh.permute(0, 2, 1) attn = attn * self.scale_tensor + atten_mask attn = self.attn_softmax(attn) - y = attn @ v.transpose(1, 2) - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - y = self.wo(y) + y = attn @ vh + + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) - return y, k, v + return y, output_kh, output_vh class LlamaDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, split_kv_cache=False): + def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() self.dim = config.dim - self.attention = LlamaAttention(config=config, split_kv_cache=split_kv_cache) + self.attention = LlamaAttention( + config=config, output_new_cache_only=output_new_cache_only + ) self.feed_forward = FeedForward(config) self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) @@ -108,16 +188,14 @@ def forward( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, atten_mask: torch.Tensor, - kv_mask: torch.Tensor, k_caches: List[torch.Tensor], v_caches: List[torch.Tensor], - ) -> Tuple[torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: h, k_cache, v_cache = self.attention( hidden_states=self.attention_norm(x), freqs_cos=freqs_cos, freqs_sin=freqs_sin, atten_mask=atten_mask, - kv_mask=kv_mask, k_caches=k_caches, v_caches=v_caches, ) @@ -127,7 +205,7 @@ def forward( class LlamaModel(nn.Module): - def __init__(self, config: ModelArgs, split_kv_cache=False): + def __init__(self, config: ModelArgs, output_new_cache_only=True): super().__init__() self.dim = config.dim self.head_dim = config.dim // config.n_heads @@ -137,10 +215,14 @@ def __init__(self, config: ModelArgs, split_kv_cache=False): self.n_kv_heads = config.n_kv_heads self.n_layers = config.n_layers self.vocab_size = config.vocab_size - self.split_kv_cache = split_kv_cache + self.rope_freq_base = config.rope_freq_base + self.output_new_cache_only = output_new_cache_only self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, split_kv_cache) for _ in range(config.n_layers)] + [ + LlamaDecoderLayer(config, self.output_new_cache_only) + for _ in range(config.n_layers) + ] ) self.norm = 
RMSNorm(config.dim, eps=config.norm_eps) self.output = nn.Linear(config.dim, config.vocab_size, bias=False) @@ -150,28 +232,14 @@ def __init__(self, config: ModelArgs, split_kv_cache=False): config.max_seq_len, config.rope_freq_base, ) - atten_mask = torch.triu( - torch.full( - (self.max_seq_len, self.max_seq_len), - -255.0, - ), - diagonal=1, - ) - self.register_buffer("atten_mask", atten_mask, persistent=False) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) - if split_kv_cache: - self.register_buffer("mask", torch.ones(self.head_dim), persistent=False) - self.register_buffer("unmask", torch.zeros(self.head_dim), persistent=False) - else: - self.register_buffer("mask", torch.ones(self.dim), persistent=False) - self.register_buffer("unmask", torch.zeros(self.dim), persistent=False) def forward( self, tokens: torch.Tensor, input_pos: torch.Tensor, - kv_mask: torch.Tensor, + atten_mask: torch.Tensor, *args, ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: output_k_cache = [] @@ -179,53 +247,28 @@ def forward( # following tensors should be invariant across batches freqs_cos = self.freqs_cos[input_pos][0] freqs_sin = self.freqs_sin[input_pos][0] - atten_mask = self.atten_mask[input_pos][0] hidden_states = self.tok_embeddings(tokens) for ind, decoder_layer in enumerate(self.layers): - if self.split_kv_cache: - offset_k = ind * self.n_heads - offset_v = self.n_layers * self.n_heads + offset_k - k_caches = args[offset_k : offset_k + self.n_heads] - v_caches = args[offset_v : offset_v + self.n_heads] - hidden_states, k, v = decoder_layer( - hidden_states, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - atten_mask=atten_mask, - kv_mask=kv_mask, - k_caches=k_caches, - v_caches=v_caches, - ) - output_k_cache.extend(k) - output_v_cache.extend(v) - else: - k_caches = args[ind] - v_caches = args[self.n_layers + ind] - hidden_states, k, v = decoder_layer( - hidden_states, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - atten_mask=atten_mask, - kv_mask=kv_mask.view( - self.max_seq_len, self.n_kv_heads, self.head_dim - ), - k_caches=k_caches, - v_caches=v_caches, - ) - output_k_cache.append(k) - output_v_cache.append(v) + offset_k = ind * self.n_heads + offset_v = self.n_layers * self.n_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_heads] + v_caches = args[offset_v : offset_v + self.n_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) hidden_states = self.norm(hidden_states) logits = self.output(hidden_states) - # TODO: add op builder for kv mask update once HTP supports more ops - # this part is now expected to be fallback on cpu - # for simplicity, input_pos is assumed to never go over max_seq_len-1 - kv_mask[input_pos] = self.unmask - kv_mask[input_pos + 1] = self.mask - - return logits, kv_mask, output_k_cache, output_v_cache + return logits, output_k_cache, output_v_cache def get_example_inputs(self): tokens = torch.randint( @@ -233,41 +276,29 @@ def get_example_inputs(self): ) pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) k_cache, v_cache = [], [] - if self.split_kv_cache: - kv_mask = torch.zeros(self.max_seq_len, self.head_dim) - kv_mask[0] = torch.ones(self.head_dim) - for _ in range(self.n_layers): - for _ in range(self.n_heads): - k_cache += torch.zeros( + atten_mask = 
torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( self.max_batch_size, - self.max_seq_len, self.head_dim, + self.max_seq_len - 1, ) - v_cache += torch.zeros( + ) + v_cache.append( + torch.zeros( self.max_batch_size, - self.max_seq_len, + self.max_seq_len - 1, self.head_dim, ) - else: - kv_mask = torch.zeros(self.max_seq_len, self.dim) - kv_mask[0] = torch.ones(self.dim) - for _ in range(self.n_layers): - k_cache += torch.zeros( - self.max_batch_size, - self.max_seq_len, - self.n_heads, - self.head_dim, - ) - v_cache += torch.zeros( - self.max_batch_size, - self.max_seq_len, - self.n_heads, - self.head_dim, ) return ( tokens, pos_ids, - kv_mask, + atten_mask, k_cache, v_cache, ) @@ -279,41 +310,29 @@ def get_export_inputs(self): pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) # this is important for torch.export not to take it as dummy input k_cache, v_cache = [], [] - if self.split_kv_cache: - kv_mask = torch.zeros(self.max_seq_len, self.head_dim) - kv_mask[0] = torch.ones(self.head_dim) - for _ in range(self.n_layers): - for _ in range(self.n_heads): - k_cache += torch.randn( + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.randn( self.max_batch_size, - self.max_seq_len, self.head_dim, + self.max_seq_len - 1, ) - v_cache += torch.randn( + ) + v_cache.append( + torch.randn( self.max_batch_size, - self.max_seq_len, + self.max_seq_len - 1, self.head_dim, ) - else: - kv_mask = torch.zeros(self.max_seq_len, self.dim) - kv_mask[0] = torch.ones(self.dim) - for _ in range(self.n_layers): - k_cache += torch.randn( - self.max_batch_size, - self.max_seq_len, - self.n_heads, - self.head_dim, - ) - v_cache += torch.randn( - self.max_batch_size, - self.max_seq_len, - self.n_heads, - self.head_dim, ) return ( tokens, pos_ids, - kv_mask, + atten_mask, k_cache, v_cache, ) diff --git a/examples/qualcomm/llama2/runner/runner.cpp b/examples/qualcomm/llama2/runner/runner.cpp index 0b4fa9e71af..379c4d2e41b 100644 --- a/examples/qualcomm/llama2/runner/runner.cpp +++ b/examples/qualcomm/llama2/runner/runner.cpp @@ -34,23 +34,29 @@ std::string statsToJsonString(const Runner::Stats& stats); } // namespace Runner::Runner( - const std::string& model_path, + const std::vector& model_path_list, const std::string& tokenizer_path, const float temperature) - : module_(std::make_unique( - model_path, - Module::MlockConfig::UseMlockIgnoreErrors)), - tokenizer_path_(tokenizer_path), + : tokenizer_path_(tokenizer_path), temperature_(temperature) { - ET_LOG( - Info, - "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", - model_path.c_str(), - tokenizer_path.c_str()); + for(auto& model_path : model_path_list){ + modules_.emplace_back(std::make_unique( + model_path, + Module::MlockConfig::UseMlockIgnoreErrors)); + ET_LOG( + Info, + "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", + model_path.c_str(), + tokenizer_path.c_str()); + } } bool Runner::is_loaded() const { - return module_->is_loaded() && tokenizer_ && sampler_; + bool loaded = true; + for(auto& module : modules_){ + loaded &= module->is_loaded(); + } + return loaded && tokenizer_ && sampler_; } Error Runner::load() { @@ -58,20 +64,23 @@ Error Runner::load() { 
return Error::Ok; } stats_.model_load_start_ms = util::time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); - - // Read out metadata from the model - ET_LOG(Info, "Reading metadata from model"); - const auto method_names = module_->method_names(); - ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); - model_methods_ = method_names.get(); - vocab_size_ = getMetadataHelper("get_vocab_size", 32000); - bos_id_ = getMetadataHelper("get_bos_id", 1); - eos_id_ = getMetadataHelper("get_eos_id", 2); - n_bos_ = getMetadataHelper("get_n_bos", 1); - n_eos_ = getMetadataHelper("get_n_eos", 1); - max_seq_len_ = getMetadataHelper("get_max_seq_len", 128); - + for(auto& module : modules_){ + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); + + + // Read out metadata from the model + ET_LOG(Info, "Reading metadata from model"); + const auto method_names = module->method_names(); + ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); + model_methods_ = method_names.get(); + vocab_size_ = getMetadataHelper(module.get(), "get_vocab_size", 32000); + bos_id_ = getMetadataHelper(module.get(), "get_bos_id", 1); + eos_id_ = getMetadataHelper(module.get(), "get_eos_id", 2); + n_bos_ = getMetadataHelper(module.get(), "get_n_bos", 1); + n_eos_ = getMetadataHelper(module.get(), "get_n_eos", 1); + max_seq_len_ = getMetadataHelper(module.get(), "get_max_seq_len", 128); + head_dim_ = getMetadataHelper(module.get(), "get_head_dim", 32); + } // Load tokenizer tokenizer_ = std::make_unique(vocab_size_, bos_id_, eos_id_); tokenizer_->load(tokenizer_path_); @@ -101,10 +110,10 @@ Error Runner::load() { } template -T Runner::getMetadataHelper(std::string method_name, T default_val) { +T Runner::getMetadataHelper(Module* module, std::string method_name, T default_val) { T res = default_val; if (model_methods_.count(method_name)) { - Result> outputs = module_->execute(method_name); + Result> outputs = module->execute(method_name); if (outputs.ok()) { std::vector outs = outputs.get(); if (outs.size() > 0) { @@ -131,32 +140,96 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { return sampler_->sample(logits_last); } + // Given an input token. Set up the inputs for the model and execute a single // step. Returning the logits tensor. 
Result Runner::run_model_step( int64_t input_token, Tensor& token, Tensor& start_pos, - std::vector& input_tensors) { + Tensor& atten_mask, + Tensor& freqs_cos, + Tensor& freqs_sin, + std::vector>& kv_tensors, + std::vector>& kv_outputs) { token.mutable_data_ptr()[0] = input_token; - // inputs:[tokens, start_pos, kv_mask, k_cache, v_cache] - // input_tensors:[kv_mask, k_cache, v_cache] - std::vector inputs = {token, start_pos}; - inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end()); - - Result> outputs_res = module_->forward(inputs); + // embedding + std::vector inputs = {token}; + Result> outputs_res = modules_[0]->forward(inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + EValue hidden_states = outputs_res.get()[0]; + + // llama block + std::vector>> llama_block_results; + for(int i = 1; i < modules_.size() - 1; ++i){ + inputs = {hidden_states, freqs_cos, freqs_sin, atten_mask}; + inputs.insert(inputs.end(), kv_tensors[i-1].begin(), kv_tensors[i-1].end()); + Result> llama_block_outputs_res = modules_[i]->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(llama_block_outputs_res.error()); + hidden_states = llama_block_outputs_res.get()[0]; + } + + // TODO: need to handle batch size != 1 + // update k_cache + size_t v_offset = kv_outputs[0][0].nbytes(); + size_t el_size = kv_outputs[0][0].element_size(); + size_t k_input_step = (max_seq_len_-1) * el_size; + for (int i = 1; i < modules_.size() - 1; ++i) { + int k_tensors_end = kv_tensors[i].size() / 2; + //update k caches + for (int j = 0, index = i-1; j < k_tensors_end; ++j) { + char *input_addr = static_cast(kv_tensors[index][j].mutable_data_ptr()); + char *output_addr = static_cast(kv_outputs[index][j].mutable_data_ptr()); + + // fill the output k values back + #pragma omp parallel for + for (int src = 0, dst = k_input_step; src < kv_outputs[index][j].nbytes(); src+=el_size, dst+=k_input_step) { + memcpy(input_addr+dst, output_addr+src, el_size); + } + + // inputs + ET_CHECK_MSG( + internal::set_tensor_data(kv_tensors[index][j], input_addr + kv_tensors[index][j].element_size(), kv_tensors[index][j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating kv_cache"); + } + // update v caches + for (int j = k_tensors_end, index = i-1; j < kv_tensors[index].size(); ++j) { + // inputs + char *input_addr = static_cast(kv_tensors[index][j].mutable_data_ptr()) + v_offset; + ET_CHECK_MSG( + internal::set_tensor_data(kv_tensors[index][j], input_addr, kv_tensors[index][j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating kv_cache"); + + // outputs + char *output_addr = static_cast(kv_outputs[index][j].mutable_data_ptr()) + v_offset; + ET_CHECK_MSG( + internal::set_tensor_data(kv_outputs[index][j], output_addr, kv_outputs[index][j].nbytes()) == Error::Ok, + "Failed to set output tensor when updating kv_cache"); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[index][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } + } + + // affine module + inputs = {hidden_states}; + Result> logits_outputs_res = modules_[modules_.size()-1]->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(logits_outputs_res.error()); // Bump start_pos by 1 start_pos.mutable_data_ptr()[0]++; - return outputs_res.get()[1].toTensor(); -} + // update atten_mask + atten_mask.mutable_data_ptr()[atten_mask.numel() - 1 - start_pos.const_data_ptr()[0]] = 0; + + return logits_outputs_res.get()[0].toTensor(); +} // TODO: add overloaded method for on-device tokenize Error Runner::generate( const 
std::string& prompt, int32_t seq_len, - std::vector& managed_inputs, + std::vector>& managed_kv_inputs, + std::vector>& freqs_inputs, std::function token_callback, std::function stats_callback) { ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); @@ -195,7 +268,22 @@ Error Runner::generate( std::vector start_pos_data = {0}; std::vector start_pos_shape = {1, 1}; - std::vector logits_data(vocab_size_); + std::vector atten_mask_data(max_seq_len_); + std::fill(atten_mask_data.begin(), atten_mask_data.end()-1, -255.0); + atten_mask_data.back() = 0; + + std::vector freqs_cos_data(head_dim_/2); + std::fill(freqs_cos_data.begin(), freqs_cos_data.end(), 0.0); + + std::vector freqs_sin_data(head_dim_/2); + std::fill(freqs_sin_data.begin(), freqs_sin_data.end(), 0.0); + + std::vector freqs_cos_shape = {1, head_dim_/2}; + + std::vector freqs_sin_shape = {1, head_dim_/2}; + + std::vector atten_mask_shape = {1, max_seq_len_}; + std::vector logits_data_shape = {1, vocab_size_}; // initialize tensor wrappers @@ -203,34 +291,105 @@ Error Runner::generate( token_data.data(), 128, token_shape, ScalarType::Int); ManagedTensor managed_pos_id( start_pos_data.data(), 128, start_pos_shape, ScalarType::Int); - ManagedTensor managed_logits( - logits_data.data(), 128, logits_data_shape, ScalarType::Float); + ManagedTensor managed_atten_mask( + atten_mask_data.data(), 128, atten_mask_shape, ScalarType::Float); + ManagedTensor managed_freqs_cos( + freqs_cos_data.data(), 128, freqs_cos_shape, ScalarType::Float); + ManagedTensor managed_freqs_sin( + freqs_sin_data.data(), 128, freqs_sin_shape, ScalarType::Float); + - Tensor logits = managed_logits.get_aliasing_tensor(); Tensor token = managed_token.get_aliasing_tensor(); + Tensor atten_mask = managed_atten_mask.get_aliasing_tensor(); Tensor start_pos = managed_pos_id.get_aliasing_tensor(); - - // TODO: investigate why kv_mask was duplicated in the output - // current output: [kv_mask, logits, k_cache, v_cache, kv_mask] - // change following indexes back when issue got resolved - std::vector inputs; - for (int i = 0; i < managed_inputs.size(); ++i) { - inputs.push_back(managed_inputs[i].get_aliasing_tensor()); + Tensor freqs_cos = managed_freqs_cos.get_aliasing_tensor(); + Tensor freqs_sin = managed_freqs_sin.get_aliasing_tensor(); + + // embedding + std::vector embedding_logits_data(vocab_size_); + ManagedTensor embedding_managed_logits( + embedding_logits_data.data(), 128, logits_data_shape, ScalarType::Float); + Tensor embedding_logits = embedding_managed_logits.get_aliasing_tensor(); + ET_CHECK_MSG( + modules_[0]->set_output_data_ptr(embedding_logits, 0) == Error::Ok, + "Failed to set output tensor for embedding module - logits"); + + // llama block + std::vector> llama_block_logit_tensor_data(modules_.size()-2); + std::vector llama_block_logit_tensors, kv_outputs_managed; + std::vector> kv_tensors(modules_.size()-2), kv_outputs(modules_.size()-2); + std::vector> methods_meta = get_methods_meta(); + + for (int i = 1; i < modules_.size() - 1; ++i){ + Result &cur_meta = methods_meta[i]; + std::vector logits_data(vocab_size_); + llama_block_logit_tensor_data.push_back(logits_data); + llama_block_logit_tensors.emplace_back(ManagedTensor( + logits_data.data(), 128, logits_data_shape, ScalarType::Float)); + Tensor logits = llama_block_logit_tensors.back().get_aliasing_tensor(); + const int k_caches_end = managed_kv_inputs[i-1].size()/2; + + // k caches init + for (int j = 0; j < k_caches_end; ++j) { + 
kv_tensors[i-1].push_back(managed_kv_inputs[i-1][j].get_aliasing_tensor()); + Result out_tensor_meta = cur_meta->output_tensor_meta(j+1); + auto tensor_shape = out_tensor_meta->sizes(); + std::vector out_tensor_shape( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + int output_offset = (out_tensor_meta->nbytes()+kv_tensors[i-1][j].element_size()) * (max_seq_len_-1); + char *output_addr = static_cast(kv_tensors[i-1][j].mutable_data_ptr()) + output_offset; + + kv_outputs_managed.push_back(ManagedTensor( + output_addr, 128, out_tensor_shape, kv_tensors[i-1][j].scalar_type())); + kv_outputs[i-1].push_back(kv_outputs_managed.back().get_aliasing_tensor()); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[i-1][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } + // v caches init + for (int j = k_caches_end; j < managed_kv_inputs[i-1].size(); ++j) { + kv_tensors[i-1].push_back(managed_kv_inputs[i-1][j].get_aliasing_tensor()); + char *output_addr = static_cast(kv_tensors[i-1][j].mutable_data_ptr()) + + (max_seq_len_-1)*head_dim_*kv_tensors[i-1][j].element_size(); + + Result out_tensor_meta = cur_meta->output_tensor_meta(j+1); + auto tensor_shape = out_tensor_meta->sizes(); + std::vector out_tensor_shape( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + kv_outputs_managed.push_back(ManagedTensor( + output_addr, 128, out_tensor_shape, kv_tensors[i-1][j].scalar_type())); + kv_outputs[i-1].push_back(kv_outputs_managed.back().get_aliasing_tensor()); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[i-1][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } ET_CHECK_MSG( - module_->set_output_data_ptr(inputs.back(), i + 2) == Error::Ok, - "Failed to set output tensor"); + modules_[i]->set_output_data_ptr(logits, 0) == Error::Ok, + "Failed to set output tensor for llama block - logits"); } + + // affine layer + std::vector affine_logits_data(vocab_size_); + ManagedTensor affine_managed_logits( + affine_logits_data.data(), 128, logits_data_shape, ScalarType::Float); + Tensor affine_logits = affine_managed_logits.get_aliasing_tensor(); ET_CHECK_MSG( - module_->set_output_data_ptr(logits, 1) == Error::Ok, - "Failed to set output tensor - logits"); + modules_[modules_.size()-1]->set_output_data_ptr(affine_logits, 0) == Error::Ok, + "Failed to set output tensor for affine module - logits"); // Start consuming user's prompts and generating new tokens std::string final_output; while (pos < seq_len - 1) { + for(int i = 0; i < head_dim_/2; i++){ + freqs_cos.mutable_data_ptr()[i] = freqs_inputs[0][pos*(head_dim_/2)+i]; + freqs_sin.mutable_data_ptr()[i] = freqs_inputs[1][pos*(head_dim_/2)+i]; + } + // Run the model Result logits_res = - run_model_step(cur_token, token, start_pos, inputs); - + run_model_step(cur_token, token, start_pos, atten_mask, freqs_cos, freqs_sin, kv_tensors, kv_outputs); if (pos == num_prompt_tokens) { stats_.first_token_ms = util::time_in_ms(); } else if (pos == num_prompt_tokens - 1) { @@ -240,8 +399,8 @@ Error Runner::generate( ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); exec_aten::Tensor& logits_tensor = logits_res.get(); prev_token = cur_token; - long sample_start_time_ms = util::time_in_ms(); + cur_token = logitsToToken(logits_tensor); stats_.aggregate_sampling_time_ms += util::time_in_ms() - sample_start_time_ms; @@ -370,15 +529,21 @@ void Runner::stop() { shouldStop_ = true; } -Result Runner::method_meta() { - return module_->method_meta("forward"); +std::vector> 
Runner::get_methods_meta() { + std::vector> tmp; + for (auto& module : modules_){ + tmp.push_back(module->method_meta("forward")); + } + return tmp; } // explicit instantiation of template methods template int64_t Runner::getMetadataHelper( + Module* module, std::string method_name, int64_t default_val); template bool Runner::getMetadataHelper( + Module* module, std::string method_name, bool default_val); diff --git a/examples/qualcomm/llama2/runner/runner.h b/examples/qualcomm/llama2/runner/runner.h index 120e8af5be7..ffda2eb37cb 100644 --- a/examples/qualcomm/llama2/runner/runner.h +++ b/examples/qualcomm/llama2/runner/runner.h @@ -29,7 +29,7 @@ namespace executor { class Runner { public: explicit Runner( - const std::string& model_path, + const std::vector& model_path_list, const std::string& tokenizer_path, const float temperature = 0.8f); @@ -64,23 +64,28 @@ class Runner { Error generate( const std::string& prompt, int32_t seq_len, - std::vector& managed_inputs, + std::vector>& managed_kv_inputs, + std::vector>& freqs_inputs, std::function token_callback = {}, std::function stats_callback = {}); void stop(); - Result method_meta(); + std::vector> get_methods_meta(); private: // metadata template - T getMetadataHelper(std::string method_name, T default_val); + T getMetadataHelper(Module*, std::string method_name, T default_val); template int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); Result run_model_step( int64_t input_token, - Tensor& token, - Tensor& start_pos, - std::vector& input_tensors); + Tensor& token, + Tensor& start_pos, + Tensor& atten_mask, + Tensor& freqs_cos, + Tensor& freqs_sin, + std::vector>& kv_tensors, + std::vector>& kv_outputs); // metadata int32_t vocab_size_; int64_t bos_id_; @@ -88,8 +93,9 @@ class Runner { int32_t n_bos_; int32_t n_eos_; int32_t max_seq_len_; + int32_t head_dim_; std::unordered_set model_methods_; - std::unique_ptr module_; + std::vector> modules_; std::string tokenizer_path_; float temperature_; std::unique_ptr tokenizer_; diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index 1064d9ff3a2..6e43835bfdf 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -6,6 +6,7 @@ import argparse import os +import re import subprocess import sys from pathlib import Path @@ -15,7 +16,6 @@ import numpy as np import torch -from torch.ao.quantization.observer import MovingAverageMinMaxObserver from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a4w_qnn_ptq_config, @@ -35,6 +35,7 @@ from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torch.ao.quantization.observer import MovingAverageMinMaxObserver from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -54,11 +55,11 @@ def __init__( ): self.qnn_sdk = qnn_sdk self.artifact_path = artifact_path - self.pte_path = pte_path + self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path] self.workspace = workspace self.device_id = device_id self.host_id = host_id - self.working_dir = Path(self.pte_path).parent.absolute() + self.working_dir = Path(self.pte_path[0]).parent.absolute() self.input_list_filename = "input_list.txt" self.etdump_path = f"{self.workspace}/etdump.etdp" self.output_folder = f"{self.workspace}/outputs" @@ -96,7 +97,6 @@ def 
push(self, inputs, input_list, files=None): # necessary artifacts for artifact in [ - f"{self.pte_path}", f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtp.so", ( f"{self.qnn_sdk}/lib/hexagon-v{self.soc_model}/" @@ -111,11 +111,13 @@ def push(self, inputs, input_list, files=None): f"{self.artifact_path}/{self.runner}", f"{self.artifact_path}/backends/qualcomm/libqnn_executorch_backend.so", input_list_file, - ]: + ] + self.pte_path: self._adb(["push", artifact, self.workspace]) # input data for idx, data in enumerate(inputs): + # print("[Warning] inputs push are is skip") + # break flat_inputs = [] for input in data: if isinstance(input, list): @@ -137,9 +139,12 @@ def execute(self, custom_runner_cmd=None): self._adb(["shell", f"mkdir -p {self.output_folder}"]) # run the delegation if custom_runner_cmd is None: + pte_path_str = ",".join( + [os.path.basename(pte_path) for pte_path in self.pte_path] + ) qnn_executor_runner_args = " ".join( [ - f"--model_path {os.path.basename(self.pte_path)}", + f"--model_paths {pte_path_str}", f"--output_folder_path {self.output_folder}", f"--input_list_path {self.input_list_filename}", f"--etdump_path {self.etdump_path}", @@ -185,7 +190,7 @@ def build_executorch_binary( direct_io=False, # TODO: temporal workaround for llama shared_buffer=False, metadata=None, - act_observer=MovingAverageMinMaxObserver + act_observer=MovingAverageMinMaxObserver, ): if quant_dtype is not None: quantizer = QnnQuantizer() @@ -196,10 +201,14 @@ def build_executorch_binary( pass # default setting elif quant_dtype == QuantDtype.use_16a16w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config(act_observer=act_observer)) + quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=act_observer) + ) elif quant_dtype == QuantDtype.use_16a4w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config(act_observer=act_observer)) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=act_observer) + ) quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") else: raise AssertionError(f"No support for QuantDtype {quant_dtype}.") @@ -270,6 +279,7 @@ def build_executorch_binary( constant_methods=metadata, compile_config=EdgeCompileConfig(_check_ir_validity=False), ) + edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner) exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) with open(f"{file_name}.pte", "wb") as file:
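---

Illustrative sketches of the main ideas in this patch follow; they use toy shapes, are not part of the patch itself, and are hedged wherever behavior is inferred rather than shown in the diff.

1. Per-head projections (prepare_sha). The prepare_sha pass in static_llama.py replaces each fused attention projection with one nn.Linear per head so attention can run as single-head matmuls. A minimal sketch of the same weight slicing, with toy sizes (this is the idea, not the shipped module):

    import torch
    import torch.nn as nn

    dim, n_heads = 64, 4
    head_dim = dim // n_heads

    # fused multi-head projection, as in LlamaAttention.wq
    wq = nn.Linear(dim, n_heads * head_dim, bias=False)

    # one single-head projection per head, as built by prepare_sha
    wq_sha = nn.ModuleList(
        nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)
    )
    for i in range(n_heads):
        # rows [i*head_dim, (i+1)*head_dim) of the fused weight belong to head i
        wq_sha[i].weight.data.copy_(wq.weight[i * head_dim : (i + 1) * head_dim])

    # concatenated per-head outputs reproduce the fused projection
    x = torch.randn(1, 1, dim)
    assert torch.allclose(wq(x), torch.cat([w(x) for w in wq_sha], dim=-1), atol=1e-6)

Note that the wk/wv slices in prepare_sha also iterate over range(self.n_heads); that appears to assume n_kv_heads == n_heads (true for stories110M) and would need adjusting for grouped-query checkpoints.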
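2. Half-split rotary embedding (apply_rotary_emb_single). The helper rotates even/odd channel pairs but concatenates the real and imaginary halves instead of re-interleaving them, so the channel order differs from the stock interleaved apply_rotary_emb. Because the same permutation is applied to both q and the keys before caching, the q·k scores should match the interleaved formulation; the MHA path in the diff still carries a "TODO verify the correctness" note, so treat this as the intended reading rather than a verified equivalence. A standalone copy with a quick norm-preservation check (the angle construction below is illustrative):

    import torch

    def apply_rotary_emb_single(x, freqs_cos, freqs_sin):
        # even/odd channel split, rotate, then concatenate (not re-interleave)
        x_r, x_i = x[..., ::2], x[..., 1::2]
        x_out_r = x_r * freqs_cos - x_i * freqs_sin
        x_out_i = x_r * freqs_sin + x_i * freqs_cos
        return torch.cat([x_out_r, x_out_i], dim=-1)

    head_dim, pos = 8, 3
    theta = 10000.0 ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    freqs = pos * theta                      # angles for one position
    freqs_cos, freqs_sin = freqs.cos(), freqs.sin()

    q = torch.randn(1, 1, head_dim)
    q_rot = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
    # a rotation preserves the norm of each (real, imag) pair, hence of the vector
    assert torch.allclose(q.norm(), q_rot.norm(), atol=1e-5)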
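3. Fixed-shape rolling KV cache during calibration. calibrate() in llama.py keeps caches of length max_seq_len - 1: each step it drops the oldest entry, appends the slice the model just returned, and unmasks one more position in atten_mask from the right. A toy version of that bookkeeping for a single v-style cache (the k caches follow the same idea along their last, transposed axis):

    import torch

    max_seq_len, head_dim = 8, 4
    v_cache = torch.zeros(1, max_seq_len - 1, head_dim)   # past positions only
    atten_mask = torch.full((1, max_seq_len), -255.0)
    atten_mask[:, -1] = 0                                  # current token is always visible

    pos = 0
    for _ in range(5):
        new_v = torch.randn(1, 1, head_dim)                # new cache slice from the model
        # drop the oldest position, append the newest; shapes never change
        v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)
        pos += 1
        atten_mask[0][-pos - 1] = 0                        # expose one more past position

    print((atten_mask == 0).sum().item())                  # -> 6 visible positions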
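4. Static-shape single-head attention step. In forward_sha each head concatenates its cached keys/values with the current token's k/v (keys are kept transposed to (head_dim, seq)), so every matmul has a fixed max_seq_len extent regardless of position and causality is enforced purely through atten_mask. One head, one decode step, with toy sizes:

    import torch

    head_dim, max_seq_len = 4, 8
    scale = head_dim ** -0.5

    q = torch.randn(1, 1, head_dim)                        # current token, one head
    k_new = torch.randn(1, head_dim, 1)                    # new key, already transposed
    v_new = torch.randn(1, 1, head_dim)
    k_cache = torch.zeros(1, head_dim, max_seq_len - 1)    # past keys, (head_dim, seq) layout
    v_cache = torch.zeros(1, max_seq_len - 1, head_dim)
    atten_mask = torch.full((1, max_seq_len), -255.0)
    atten_mask[:, -1] = 0                                  # only the current position visible

    kh = torch.cat([k_cache, k_new], dim=-1)               # (1, head_dim, max_seq_len)
    vh = torch.cat([v_cache, v_new], dim=1)                # (1, max_seq_len, head_dim)

    attn = (q @ kh) * scale + atten_mask                   # (1, 1, max_seq_len)
    attn = torch.softmax(attn, dim=-1)
    y = attn @ vh                                          # (1, 1, head_dim)
    print(y.shape)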
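5. In-place cache update in the runner. As I read run_model_step, instead of shifting the whole k cache every token, the C++ code scatters the head_dim new key values into the byte offsets that become the last column once the tensor's data pointer is advanced by one element via internal::set_tensor_data. A NumPy sketch of that sliding-window trick; the buffer name and the explicit one-element headroom are illustrative, since the real code manipulates the ExecuTorch tensor's raw bytes:

    import numpy as np

    head_dim, max_seq_len = 4, 8
    L = max_seq_len - 1                                   # cached positions per head

    # flat buffer with one element of headroom so the (head_dim, L) window can slide
    buf = np.zeros(head_dim * L + 1, dtype=np.float32)
    offset = 0                                            # element offset of the current window

    def window(off):
        # the fixed (head_dim, L) view the lowered graph reads as its k-cache input
        return buf[off : off + head_dim * L].reshape(head_dim, L)

    new_k = np.arange(1, head_dim + 1, dtype=np.float32)  # new key column from this step

    # write element j of the new column at (j + 1) * L, i.e. where row j will end
    # after the window slides, then advance the window by a single element
    for j in range(head_dim):
        buf[offset + (j + 1) * L] = new_k[j]
    offset += 1

    print(window(offset)[:, -1])                          # -> [1. 2. 3. 4.]

The v caches take the simpler route visible in the diff: the output pointer is parked one row past the input window, so the freshly written row becomes the newest cached position when the input pointer advances by v_offset.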
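6. Split-graph decode step. The runner treats the PTE list as embedding -> N transformer-block graphs -> final projection (the "affine" module), threading hidden states through the chain while the host supplies freqs_cos/freqs_sin and the attention mask. A schematic host-side view with hypothetical callables standing in for Module::forward; the real per-module signatures come from the CompositeLlama split, which is not shown in this hunk:

    def decode_step(modules, token, freqs_cos, freqs_sin, atten_mask, kv_inputs):
        # modules[0]: token embedding, modules[1:-1]: transformer blocks,
        # modules[-1]: output projection; kv_inputs holds per-block cache tensors
        hidden = modules[0](token)
        new_kv = []
        for block, kv in zip(modules[1:-1], kv_inputs):
            hidden, *kv_out = block(hidden, freqs_cos, freqs_sin, atten_mask, *kv)
            new_kv.append(kv_out)
        logits = modules[-1](hidden)
        return logits, new_kv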