
[Draft] Qualcomm AI Engine Direct - Support kv_cached llama2 model #2966

Closed · wants to merge 3 commits
10 changes: 8 additions & 2 deletions backends/qualcomm/builders/__init__.py
@@ -10,7 +10,6 @@
op_avg_pool2d,
op_batch_norm,
op_bmm,
- op_cast,
op_cat,
op_ceil,
op_clamp,
@@ -41,9 +40,13 @@
op_skip_ops,
op_slice_copy,
op_softmax,
+ op_split,
+ op_sqrt,
op_squeeze,
op_sub,
op_sum_int_list,
+ op_tanh,
+ op_to,
op_transpose,
op_unsqueeze,
op_upsample_bilinear2d,
@@ -55,7 +58,6 @@
op_avg_pool2d,
op_batch_norm,
op_bmm,
- op_cast,
op_cat,
op_ceil,
op_clamp,
@@ -85,9 +87,13 @@
op_skip_ops,
op_slice_copy,
op_softmax,
+ op_split,
op_squeeze,
+ op_sqrt,
op_sub,
op_sum_int_list,
+ op_tanh,
+ op_to,
op_transpose,
op_unsqueeze,
op_upsample_bilinear2d,
61 changes: 24 additions & 37 deletions backends/qualcomm/builders/node_visitor.py
@@ -14,9 +14,13 @@

from executorch.exir.dialects._ops import ops as exir_ops

- from .qnn_constants import QNN_uint16
-
- from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
+ from .utils import (
+     deduce_dtype,
+     get_parameter,
+     is_graph_input,
+     is_graph_output,
+     is_parameter,
+ )


QNN_QUANT_TYPE_MAP = {
@@ -26,16 +30,17 @@
# Note that there is no int64 tensor data type in Qnn.
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
- QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+ torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
QNN_TENSOR_TYPE_MAP = {
torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
- QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+ torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}

@@ -169,7 +174,7 @@ def get_quant_encoding_conf(
return self.make_qnn_per_tensor_config(quant_attrs)

def get_quant_tensor_value(
- self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+ self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
) -> torch.Tensor:
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
scale = quant_attrs["scale"]
@@ -178,16 +183,11 @@ def get_quant_tensor_value(
scale = quant_attrs["scales"]
zero_point = quant_attrs["zero_points"]

- # To bypass torch.uint16 quantization is not supported
- dtype = (
-     torch.int32
-     if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-     else quant_attrs["dtype"]
- )
+ dtype = quant_configs["dtype"]

tensor = tensor.div(scale).add(zero_point).round().to(dtype)
# Make the backends access data correctly
- if bitwidth == 4:
+ if quant_configs.get("bitwidth") == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
tensor = torch.bitwise_and(mask, tensor)
return tensor
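
For reference, a minimal sketch of the per-tensor path above, with hypothetical scale/zero-point values standing in for the node's real quant_attrs:

```python
import torch

# Hypothetical per-tensor quantization parameters (illustrative only).
scale, zero_point = 0.05, 128
x = torch.tensor([-1.0, 0.0, 1.0])

# Same transform as get_quant_tensor_value: q = round(x / scale) + zero_point.
q = x.div(scale).add(zero_point).round().to(torch.uint8)
# q == tensor([108, 128, 148], dtype=torch.uint8)

# For 4-bit configs the builder additionally keeps only the low nibble so the
# backend reads the packed data correctly:
q4 = torch.bitwise_and(torch.full(q.size(), 0x0F, dtype=torch.int8), q.to(torch.int8))
```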
@@ -221,24 +221,9 @@ def get_data_type(
self,
tensor: torch.Tensor,
quant_config: Dict,
- is_tensor: bool,
) -> PyQnnWrapper.Qnn_TensorType_t:
- if quant_config and is_tensor:
-     quant_range = quant_config["quant_max"] - quant_config["quant_min"]
-     unsigned = quant_config["quant_min"] >= 0
-     if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
-         if unsigned:
-             quant_config["dtype"] = torch.uint8
-         else:
-             quant_config["dtype"] = torch.int8
-     elif (
-         quant_range
-         <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
-     ):
-         if unsigned:
-             quant_config["dtype"] = QNN_uint16
-         else:
-             quant_config["dtype"] = torch.int16
+ if quant_config:
+     quant_config["dtype"] = deduce_dtype(tensor, quant_config)
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
else:
return QNN_TENSOR_TYPE_MAP[tensor.dtype]
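
The range-based selection that previously lived inline here now sits behind `deduce_dtype` in `utils.py`. A sketch of equivalent logic, reconstructed from the removed branch above (the helper's actual signature in `utils.py` may differ):

```python
import torch

def deduce_dtype_sketch(quant_config: dict) -> torch.dtype:
    # Pick the narrowest dtype whose range covers [quant_min, quant_max],
    # preferring unsigned types when quant_min is non-negative.
    quant_range = quant_config["quant_max"] - quant_config["quant_min"]
    unsigned = quant_config["quant_min"] >= 0
    if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
        return torch.uint8 if unsigned else torch.int8
    return torch.uint16 if unsigned else torch.int16
```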
@@ -283,7 +268,7 @@ def define_tensor(
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
is_input_tensor: bool,
node_name: str = None,
- is_tensor: bool = True,
+ wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
"""
Covert torch.Tensor to TensorWrapper
@@ -299,17 +284,20 @@
if node_name is None:
node_name = node.name

- if node_name in nodes_to_wrappers:
-     return nodes_to_wrappers[node_name]
+ if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
+     return cached

tensor_name = node.name
if is_graph_input(node, self.edge_program):
tensor_name = "QnnInput_" + str(self.external_ids[node]) + "_" + tensor_name
if is_graph_output(node):
tensor_name = "output_" + tensor_name
dims = [1] if len(tensor.size()) == 0 else tensor.size()
tensor_type = self.get_tensor_type(node, tensor_type)
quant_encoding, quant_configs = self.get_quant_encoding_conf(
node, is_input_tensor
)
- dtype = self.get_data_type(tensor, quant_configs, is_tensor)
+ dtype = self.get_data_type(tensor, quant_configs)
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
tensor_wrapper = PyQnnWrapper.TensorWrapper(
tensor_name,
@@ -327,8 +315,7 @@
tensor = self.get_quant_tensor_value(
tensor,
node.meta["quant_attrs"],
- dtype,
- quant_configs.get("bitwidth"),
+ quant_configs,
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
tensor_name,
Expand All @@ -341,7 +328,7 @@ def define_tensor(
tensor.detach().numpy(),
True,
)
- nodes_to_wrappers[node_name] = tensor_wrapper
+ nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
return tensor_wrapper

def define_node(
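
With `wrapper_idx` in play, `nodes_to_wrappers` becomes a two-level map (node name → wrapper index → wrapper), so multi-output ops such as split can cache one wrapper per output. A minimal sketch of the caching pattern, with plain strings standing in for PyQnnWrapper tensor wrappers:

```python
from collections import defaultdict

# node name -> wrapper index -> wrapper; single-output nodes use index 0.
nodes_to_wrappers = defaultdict(dict)

nodes_to_wrappers["split"][0] = "wrapper_for_output_0"
nodes_to_wrappers["split"][1] = "wrapper_for_output_1"

# Same lookup shape as define_tensor's cache check above:
cached = nodes_to_wrappers["split"].get(1, None)
assert cached == "wrapper_for_output_1"
```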
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_embedding.py
@@ -34,7 +34,7 @@ def define_node(
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
+ is_input_tensor=True,
)

indices_node = node.args[1]
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_linear.py
@@ -62,7 +62,7 @@ def define_node(
bias_node = node.args[2]

# TODO remove this when qnn sdk support
if "scales" in bias_node.meta.get("quant_attrs"):
if "scales" in bias_node.meta.get("quant_attrs", {}):
print(
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
)
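
The added `{}` default matters because `meta.get("quant_attrs")` returns `None` for an unquantized bias, and a membership test against `None` raises. A two-line illustration:

```python
meta = {}  # a bias node with no quantization attributes
# "scales" in meta.get("quant_attrs")  # TypeError: argument of type 'NoneType' is not iterable
assert ("scales" in meta.get("quant_attrs", {})) is False  # falls through safely
```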
1 change: 0 additions & 1 deletion backends/qualcomm/builders/op_log_softmax.py
@@ -72,5 +72,4 @@ def define_node(
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(dim)},
)
- # pdb.set_trace()
return log_softmax_op
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_skip_ops.py
@@ -46,5 +46,7 @@ def define_node(
raise AssertionError(
f"Invalid number of index for {node.name }: {len(node.args[1])}"
)
- nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
+ nodes_to_wrappers[node.name] = {
+     0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1])
+ }
return
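
This aligns the skipped `getitem` ops with the two-level wrapper cache: `node.args[0]` is the multi-output producer and `node.args[1]` picks which of its wrappers to alias under index 0. A toy illustration with strings in place of wrappers:

```python
nodes_to_wrappers = {"split": {0: "w0", 1: "w1", 2: "w2"}}

# For `y = getitem(split, 1)` the visitor records:
nodes_to_wrappers["getitem_1"] = {0: nodes_to_wrappers["split"][1]}

# Downstream visitors can now treat `y` as an ordinary single-output node.
assert nodes_to_wrappers["getitem_1"][0] == "w1"
```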
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_slice_copy.py
@@ -61,7 +61,9 @@ def define_node(
ranges = []
for i in range(input_tensor_rank):
if i == dim:
- ranges.extend([start, end, 1])
+ # find step
+ step = node.args[4] if len(node.args) > 4 else 1
+ ranges.extend([start, end, step])
else:
ranges.extend([0, input_tensor.shape[i], 1])

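
For context, `aten.slice_copy` carries `(input, dim, start, end, step)` with `step` optional, which is what the new lookup of `node.args[4]` reads. A quick check of the semantics the emitted `[start, end, step]` triple has to match (illustrative shapes):

```python
import torch

x = torch.arange(16).reshape(2, 8)
y = torch.ops.aten.slice_copy.default(x, 1, 0, 8, 2)  # like x[:, 0:8:2]
print(y.shape)  # torch.Size([2, 4])
# For dim 1 the builder now emits [0, 8, 2]; every other
# dimension keeps the full range [0, size, 1].
```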
85 changes: 85 additions & 0 deletions backends/qualcomm/builders/op_split.py
@@ -0,0 +1,85 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Split(NodeVisitor):
    target = ["aten.split_with_sizes.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        split_input_tensors = [input_tensor_wrapper]

        axis = 0 if len(node.args) < 3 else cast(int, node.args[2])
        if axis < 0:
            axis = axis % len(input_tensor.shape)
        if "axis_order" in node.meta:
            axis = node.meta["axis_order"].index(axis)

        # this is not the general case, only a quick workaround here
        index = np.arange(1, input_tensor.shape[axis], dtype=np.uint32)
        index_shape = [len(index)]

        split_output_tensors = []
        for i in range(input_tensor.shape[axis]):
            output_tensor = self.get_tensor(node, node, i)
            output_tensor_wrapper = self.define_tensor(
                node,
                output_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
                is_input_tensor=False,
                wrapper_idx=i,
            )
            split_output_tensors.append(output_tensor_wrapper)

        split_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSplit.op_name,
        )
        split_op.AddInputTensors(split_input_tensors)
        split_op.AddOutputTensors(split_output_tensors)

        split_op.AddScalarParam(
            OpSplit.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(axis)},
        )
        split_op.AddTensorParam(
            OpSplit.param_split_index,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(index_shape),
            index_shape,
            index,
            True,
        )

        return split_op
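
The workaround above always splits into size-1 slices along `axis`: QNN's `split_index` parameter lists the indices where new slices begin, so `np.arange(1, dim_size)` cuts before every position and produces `dim_size` outputs. The torch-side equivalent, for an illustrative shape:

```python
import numpy as np
import torch

x = torch.randn(4, 8)

# aten.split_with_sizes with all-ones sizes matches the builder's workaround:
outs = torch.split_with_sizes(x, [1, 1, 1, 1], dim=0)
print(len(outs), outs[0].shape)  # 4 torch.Size([1, 8])

# QNN encodes the same cuts as start indices of each new slice:
split_index = np.arange(1, x.shape[0], dtype=np.uint32)  # array([1, 2, 3], dtype=uint32)
```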
backends/qualcomm/builders/{op_cast.py → op_sqrt.py}
@@ -10,12 +10,12 @@
import torch

from .node_visitor import NodeVisitor, register_node_visitor
- from .qnn_constants import OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW
+ from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
- class Cast(NodeVisitor):
-     target = ["aten._to_copy.default"]
+ class SQRT(NodeVisitor):
+     target = ["aten.sqrt.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
@@ -25,6 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
+ # tensor input
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)

@@ -35,23 +36,24 @@
nodes_to_wrappers,
is_input_tensor=True,
)
+ sqrt_input_tensors = [input_tensor_wrapper]

- output_tensor = self.get_tensor(node, node)
-
+ out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
- output_tensor,
+ out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)
+ sqrt_output_tensors = [output_tensor_wrapper]

- cast_op = PyQnnWrapper.PyQnnOpWrapper(
+ sqrt_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
- OpCast.op_name,
+ OpSqrt.op_name,
)
- cast_op.AddInputTensors([input_tensor_wrapper])
- cast_op.AddOutputTensors([output_tensor_wrapper])
+ sqrt_op.AddInputTensors(sqrt_input_tensors)
+ sqrt_op.AddOutputTensors(sqrt_output_tensors)

- return cast_op
+ return sqrt_op