
Commit aaada7f
Author: chiwwang
Publish current llama2-7b status
examples/qualcomm/llama2/llama.py can be used like:

```
python examples/qualcomm/llama2/llama.py -a llama_only_quant -b build_android -m SM8650 --ptq 16a4w --tokenizer_model tokenizer.model --checkpoint stories110M.pt --params params.json --tokenizer_bin tokenizer.bin --prompt Once
```

Note that we don't have a runner for llama2 without the split.

What we did to optimize performance on HTP:

1. Multi-head attention is transformed into multiple single-head attentions.
2. The KV cache is moved to graph I/O; the update is performed on CPU in qnn_llama_runner.cpp.
3. llama2 is partitioned into 6 pte files in examples/qualcomm/llama2/composite_llama.py.
4. The embedding layer is quantized. This might need further investigation, e.g., whether we can move it out of the model and run it on CPU.
5. u16 and u8 mixed-precision quantization is supported.
6. The KV cache stays in quantized format in graph I/O.
7. RMSNorm is tweaked a bit to reduce its quantization sensitivity.
8. The HTP spill-fill buffer feature is shared among the pte files.
9. All Linear layers are converted to Conv2d (see the sketch after this list).
10. quant_min and quant_max are set properly in Observers so that symmetric quantization uses offset=128.
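To make item 9 concrete, here is a minimal sketch of the usual Linear-to-Conv2d rewrite: a Linear layer is numerically equivalent to a 1x1 Conv2d applied to an NCHW-reshaped input, which generally maps better onto HTP. The helper name `linear_to_conv2d` and the exact reshaping below are illustrative, not the code in this commit:

```
import torch
import torch.nn as nn


def linear_to_conv2d(linear: nn.Linear) -> nn.Conv2d:
    # A [out, in] Linear weight becomes a [out, in, 1, 1] 1x1-conv weight.
    conv = nn.Conv2d(
        linear.in_features,
        linear.out_features,
        kernel_size=1,
        bias=linear.bias is not None,
    )
    conv.weight.data = linear.weight.data.view(*linear.weight.shape, 1, 1)
    if linear.bias is not None:
        conv.bias.data = linear.bias.data
    return conv


linear = nn.Linear(16, 32)
conv = linear_to_conv2d(linear)
x = torch.randn(1, 16)
# Reshape [N, C] -> [N, C, 1, 1] for the conv, then flatten back.
y = conv(x.view(1, 16, 1, 1)).view(1, 32)
assert torch.allclose(linear(x), y, atol=1e-5)
```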
1 parent 4008600 commit aaada7f

40 files changed: +3904 −442 lines

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 2 deletions
```
@@ -10,7 +10,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -41,11 +40,13 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_sqrt,
     op_squeeze,
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -57,7 +58,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -87,11 +87,13 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_squeeze,
     op_sqrt,
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
```

backends/qualcomm/builders/node_visitor.py

Lines changed: 17 additions & 23 deletions
```
@@ -14,7 +14,13 @@
 
 from executorch.exir.dialects._ops import ops as exir_ops
 
-from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
+from .utils import (
+    deduce_dtype,
+    get_parameter,
+    is_graph_input,
+    is_graph_output,
+    is_parameter,
+)
 
 
 QNN_QUANT_TYPE_MAP = {
@@ -215,24 +221,9 @@ def get_data_type(
         self,
         tensor: torch.Tensor,
         quant_config: Dict,
-        is_tensor: bool,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        if quant_config and is_tensor:
-            quant_range = quant_config["quant_max"] - quant_config["quant_min"]
-            unsigned = quant_config["quant_min"] >= 0
-            if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
-                if unsigned:
-                    quant_config["dtype"] = torch.uint8
-                else:
-                    quant_config["dtype"] = torch.int8
-            elif (
-                quant_range
-                <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
-            ):
-                if unsigned:
-                    quant_config["dtype"] = torch.uint16
-                else:
-                    quant_config["dtype"] = torch.int16
+        if quant_config:
+            quant_config["dtype"] = deduce_dtype(tensor, quant_config)
             return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
         else:
             return QNN_TENSOR_TYPE_MAP[tensor.dtype]
@@ -277,7 +268,7 @@ def define_tensor(
         nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
         is_input_tensor: bool,
         node_name: str = None,
-        is_tensor: bool = True,
+        wrapper_idx: int = 0,
     ) -> PyQnnWrapper.TensorWrapper:
         """
         Covert torch.Tensor to TensorWrapper
@@ -293,17 +284,20 @@ def define_tensor(
         if node_name is None:
             node_name = node.name
 
-        if node_name in nodes_to_wrappers:
-            return nodes_to_wrappers[node_name]
+        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
+            return cached
+
         tensor_name = node.name
+        if is_graph_input(node, self.edge_program):
+            tensor_name = "QnnInput_" + str(self.external_ids[node]) + "_" + tensor_name
         if is_graph_output(node):
             tensor_name = "output_" + tensor_name
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
         tensor_type = self.get_tensor_type(node, tensor_type)
         quant_encoding, quant_configs = self.get_quant_encoding_conf(
             node, is_input_tensor
         )
-        dtype = self.get_data_type(tensor, quant_configs, is_tensor)
+        dtype = self.get_data_type(tensor, quant_configs)
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
@@ -334,7 +328,7 @@ def define_tensor(
             tensor.detach().numpy(),
             True,
         )
-        nodes_to_wrappers[node_name] = tensor_wrapper
+        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
         return tensor_wrapper
 
     def define_node(
```
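For reference, here is a sketch of what the new `deduce_dtype` helper in builders/utils.py likely does, reconstructed from the inlined logic this diff removes; the real implementation may differ, e.g., in how it handles ranges wider than 16 bits, which the removed code did not cover:

```
from typing import Dict

import torch


def deduce_dtype(tensor: torch.Tensor, quant_config: Dict) -> torch.dtype:
    # Pick the narrowest integer dtype whose range covers
    # [quant_min, quant_max]; unsigned when quant_min is non-negative.
    quant_range = quant_config["quant_max"] - quant_config["quant_min"]
    unsigned = quant_config["quant_min"] >= 0
    if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
        return torch.uint8 if unsigned else torch.int8
    if quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
        return torch.uint16 if unsigned else torch.int16
    return tensor.dtype  # assumed fallback: keep the tensor's own dtype
```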

backends/qualcomm/builders/op_cast.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

backends/qualcomm/builders/op_embedding.py

Lines changed: 1 addition & 1 deletion
```
@@ -34,7 +34,7 @@ def define_node(
             weight_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
             nodes_to_wrappers,
-            is_input_tensor=False,
+            is_input_tensor=True,
         )
 
         indices_node = node.args[1]
```

backends/qualcomm/builders/op_skip_ops.py

Lines changed: 3 additions & 1 deletion
```
@@ -46,5 +46,7 @@ def define_node(
             raise AssertionError(
                 f"Invalid number of index for {node.name }: {len(node.args[1])}"
             )
-        nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
+        nodes_to_wrappers[node.name] = {
+            0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1])
+        }
         return
```
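Taken together with the node_visitor.py change above, these diffs imply that nodes_to_wrappers changed from name -> wrapper to name -> {wrapper_idx: wrapper}, so multi-output ops such as split can cache one wrapper per output. A minimal sketch of that cache shape, assuming a defaultdict backs it (the diff's `nodes_to_wrappers[node_name].get(wrapper_idx, None)` suggests as much):

```
from collections import defaultdict

# name -> {wrapper_idx: wrapper}; values here are placeholder strings.
nodes_to_wrappers = defaultdict(dict)
nodes_to_wrappers["split_node"][0] = "wrapper_for_output_0"
nodes_to_wrappers["split_node"][1] = "wrapper_for_output_1"

# Safe lookup even for names that were never inserted.
print(nodes_to_wrappers["split_node"].get(1, None))  # wrapper_for_output_1
print(nodes_to_wrappers["unseen_node"].get(0, None))  # None
```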

backends/qualcomm/builders/op_slice_copy.py

Lines changed: 3 additions & 1 deletion
```
@@ -61,7 +61,9 @@ def define_node(
         ranges = []
         for i in range(input_tensor_rank):
             if i == dim:
-                ranges.extend([start, end, 1])
+                # find step
+                step = node.args[4] if len(node.args) > 4 else 1
+                ranges.extend([start, end, step])
             else:
                 ranges.extend([0, input_tensor.shape[i], 1])
```
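For context on the fix above: slice_copy's arguments are (input, dim, start, end, step), with step defaulting to 1 when omitted, which is what the `len(node.args) > 4` guard checks. A quick illustration of the semantics:

```
import torch

x = torch.arange(10)
# slice_copy(x, dim=0, start=1, end=9, step=2) is equivalent to:
print(x[1:9:2])  # tensor([1, 3, 5, 7])
```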

backends/qualcomm/builders/op_split.py

Lines changed: 85 additions & 0 deletions
```
@@ -0,0 +1,85 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Split(NodeVisitor):
+    target = ["aten.split_with_sizes.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        split_input_tensors = [input_tensor_wrapper]
+
+        axis = 0 if len(node.args) < 3 else cast(int, node.args[2])
+        if axis < 0:
+            axis = axis % len(input_tensor.shape)
+        if "axis_order" in node.meta:
+            axis = node.meta["axis_order"].index(axis)
+
+        # this is not the general case, only a quick workaround here
+        index = np.arange(1, input_tensor.shape[axis], dtype=np.uint32)
+        index_shape = [len(index)]
+
+        split_output_tensors = []
+        for i in range(input_tensor.shape[axis]):
+            output_tensor = self.get_tensor(node, node, i)
+            output_tensor_wrapper = self.define_tensor(
+                node,
+                output_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                nodes_to_wrappers,
+                is_input_tensor=False,
+                wrapper_idx=i,
+            )
+            split_output_tensors.append(output_tensor_wrapper)
+
+        split_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpSplit.op_name,
+        )
+        split_op.AddInputTensors(split_input_tensors)
+        split_op.AddOutputTensors(split_output_tensors)
+
+        split_op.AddScalarParam(
+            OpSplit.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {"data": np.uint32(axis)},
+        )
+        split_op.AddTensorParam(
+            OpSplit.param_split_index,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(index_shape),
+            index_shape,
+            index,
+            True,
+        )
+
+        return split_op
```
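The "quick workaround" flagged in this file hard-codes a split point before every position along the axis, so the op always yields size-1 chunks; this matches item 1 of the commit message, decomposing multi-head attention into single heads. Assuming QNN's split_index has numpy-style "cut before each listed index" semantics, a numpy analogue would be:

```
import numpy as np

x = np.arange(12).reshape(4, 3)           # shape (4, 3)
chunks = np.split(x, np.arange(1, 4), 0)  # cut before rows 1, 2, 3
print([c.shape for c in chunks])          # [(1, 3), (1, 3), (1, 3), (1, 3)]
```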
