
Commit 3b6af64

resolve uint16 type and reorder input in runtime

1 parent 74adfc1 commit 3b6af64

File tree

10 files changed (+36, -32 lines changed)


backends/qualcomm/builders/node_visitor.py

Lines changed: 9 additions & 15 deletions
@@ -14,8 +14,6 @@
 
 from executorch.exir.dialects._ops import ops as exir_ops
 
-from .qnn_constants import QNN_uint16
-
 from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
 
 
@@ -26,7 +24,7 @@
     # Note that there is no int64 tensor data type in Qnn.
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
@@ -35,7 +33,7 @@
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }
 
@@ -169,7 +167,7 @@ def get_quant_encoding_conf(
             return self.make_qnn_per_tensor_config(quant_attrs)
 
     def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
             scale = quant_attrs["scale"]
@@ -178,16 +176,11 @@ def get_quant_tensor_value(
             scale = quant_attrs["scales"]
             zero_point = quant_attrs["zero_points"]
 
-        # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
+        dtype = quant_configs["dtype"]
 
         tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
             tensor = torch.bitwise_and(mask, tensor)
         return tensor
@@ -236,7 +229,7 @@ def get_data_type(
             <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
         ):
             if unsigned:
-                quant_config["dtype"] = QNN_uint16
+                quant_config["dtype"] = torch.uint16
             else:
                 quant_config["dtype"] = torch.int16
         return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
@@ -304,6 +297,8 @@ def define_tensor(
             return cached
 
         tensor_name = node.name
+        if is_graph_input(node, self.edge_program):
+            tensor_name = "QnnInput_"+str(self.external_ids[node])+"_"+ tensor_name
        if is_graph_output(node):
             tensor_name = "output_" + tensor_name
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
@@ -329,8 +324,7 @@ def define_tensor(
             tensor = self.get_quant_tensor_value(
                 tensor,
                 node.meta["quant_attrs"],
-                dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
             )
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
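
For reference, a minimal, self-contained sketch (not the repository's code) of what the reworked get_quant_tensor_value does on the per-tensor path: the quantization math is unchanged, but the storage dtype and the optional 4-bit packing are now read from a single quant_configs dict instead of separate dtype/bitwidth arguments. The helper name and the literal values below are illustrative only.

```python
import torch

def quantize_tensor(tensor: torch.Tensor, quant_attrs: dict, quant_configs: dict) -> torch.Tensor:
    # scale/zero_point describe the affine mapping; quant_configs supplies the
    # storage dtype (torch.uint16 after this commit) and an optional bitwidth.
    scale = quant_attrs["scale"]
    zero_point = quant_attrs["zero_point"]
    dtype = quant_configs["dtype"]

    tensor = tensor.div(scale).add(zero_point).round().to(dtype)
    if quant_configs.get("bitwidth") == 4:
        # Keep only the low nibble so the backend reads 4-bit data correctly.
        mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
        tensor = torch.bitwise_and(mask, tensor)
    return tensor

# torch.uint16 requires a recent PyTorch (2.3+); the old QNN_uint16 string
# constant is removed by this commit.
q = quantize_tensor(
    torch.randn(2, 3),
    {"scale": 0.004, "zero_point": 32768},
    {"dtype": torch.uint16},
)
```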

backends/qualcomm/builders/op_split.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 
 @register_node_visitor
-class Softmax(NodeVisitor):
+class Split(NodeVisitor):
     target = ["aten.split_with_sizes.default"]
 
     def __init__(self, *args) -> None:

backends/qualcomm/builders/qnn_constants.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from enum import IntEnum, unique
 
 QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw"
-QNN_uint16 = "uint16"
 
 # Below constants should be same as those in QNN headers.
 # Maybe someday we should expose these constants by pybind

backends/qualcomm/quantizer/utils.py

Lines changed: 4 additions & 4 deletions
@@ -113,14 +113,14 @@ def get_default_8bit_qnn_ptq_config() -> QuantizationConfig:
 
 
 # 4 bits quantization only supports specific ops.
-def get_16a4w_qnn_ptq_config() -> QuantizationConfig:
+def get_16a4w_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
     act_quantization_spec = QuantizationSpec(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
     )
 
     weight_quantization_spec = QuantizationSpec(
@@ -150,14 +150,14 @@ def get_16a4w_qnn_ptq_config() -> QuantizationConfig:
     return quantization_config
 
 
-def get_default_16bit_qnn_ptq_config() -> QuantizationConfig:
+def get_default_16bit_qnn_ptq_config(act_observer=MovingAverageMinMaxObserver) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
     act_quantization_spec = QuantizationSpec(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
     )
 
     weight_quantization_spec = QuantizationSpec(
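
The act_observer keyword added above defaults to the previous behaviour, so existing callers are unaffected. A short usage sketch (import path taken from the quantizer module used elsewhere in this commit; variable names are illustrative):

```python
from torch.ao.quantization.observer import MinMaxObserver

from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a4w_qnn_ptq_config,
    get_default_16bit_qnn_ptq_config,
)

# Default: activations observed with MovingAverageMinMaxObserver, as before.
config_16a16w = get_default_16bit_qnn_ptq_config()

# Opt in to a different activation observer, e.g. the plain MinMaxObserver
# that the llama2 example passes in below.
config_16a4w = get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver)
```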

backends/qualcomm/runtime/QnnExecuTorchBackend.cpp

Lines changed: 10 additions & 1 deletion
@@ -11,7 +11,7 @@
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorchBackend.h>
 #include <executorch/backends/qualcomm/runtime/QnnManager.h>
 #include <executorch/backends/qualcomm/schema_generated.h>
-
+#include <algorithm>
 #include <string>
 namespace torch {
 namespace executor {
@@ -20,6 +20,12 @@ using namespace qnn;
 using namespace qnn_delegate;
 constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec";
 
+bool CompareQnnInput(const std::shared_ptr<TensorWrapper>& a, const std::shared_ptr<TensorWrapper>& b) {
+  int numA = std::stoi(a->GetName().substr(a->GetName().find('_') + 1));
+  int numB = std::stoi(b->GetName().substr(b->GetName().find('_') + 1));
+  return numA < numB;
+}
+
 Result<DelegateHandle*> QnnExecuTorchBackend::init(
     BackendInitContext& context,
     FreeableBuffer* processed,
@@ -187,6 +193,9 @@ Error QnnExecuTorchBackend::execute(
       qnn_manager->GetGraphOutputs();
   std::vector<Qnn_Tensor_t> input_tensor_structs;
   std::vector<Qnn_Tensor_t> output_tensor_structs;
+  // Using the order of the nodes as external_id in AOT
+  // to extract the right arg from *args at runtime
+  std::sort(input_tensors.begin(), input_tensors.end(), CompareQnnInput);
 
   input_tensor_structs.reserve(input_tensors.size());
   for (int i = 0; i < input_tensors.size(); ++i) {
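
The sort relies on the naming convention added in node_visitor.py above: graph inputs are named "QnnInput_<external_id>_<node name>", so the integer after the first underscore recovers the argument's position in *args. Below is a small Python sketch of the same ordering logic that the CompareQnnInput comparator implements; the tensor names are made up for illustration.

```python
# Names produced at AOT time, in whatever order the backend reports them.
input_names = ["QnnInput_2_kv_mask", "QnnInput_0_tokens", "QnnInput_1_pos_ids"]

def external_id(name: str) -> int:
    # Same parsing as CompareQnnInput: take the text after the first '_'
    # and read the leading integer (std::stoi ignores the trailing name).
    return int(name.split("_")[1])

# Restore the original argument order so arg i feeds the i-th graph input.
ordered = sorted(input_names, key=external_id)
print(ordered)  # ['QnnInput_0_tokens', 'QnnInput_1_pos_ids', 'QnnInput_2_kv_mask']
```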

backends/qualcomm/runtime/backends/QnnBackendCache.cpp

Lines changed: 3 additions & 2 deletions
@@ -87,7 +87,8 @@ QnnBackendCache::QnnBackendCache(
     state_ = SERIALIZE;
     QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE.");
     return;
-  } else {
+  }
+  /*else {
     // TODO: need fix on this since qnn context binary could somehow
     // pass the check of flatbuffer verifier
     // check if context binary came from flatbuffer
@@ -100,7 +101,7 @@ QnnBackendCache::QnnBackendCache(
       state_ = ONLINE_PREPARE;
       return;
     }
-  }
+  }*/
 
   if (qnn_sys_impl_.Load() != Error::Ok) {
     QNN_EXECUTORCH_LOG_ERROR(

examples/qualcomm/executor_runner/qnn_llama_runner.cpp

Lines changed: 1 addition & 5 deletions
@@ -100,7 +100,7 @@ int main(int argc, char** argv) {
     if (input_files.size() == 0) {
       break;
     }
-    // inputs: [tokens, pos_ids, atten_mask, kv_mask, k_cache, v_cache]
+    // inputs: [tokens, pos_ids, kv_mask, *k_cache, *v_cache]
     // tokens are determined by command line arguments
     // pos_ids are infered inside runner
     std::vector<ManagedTensor> managed_inputs;
@@ -120,10 +120,6 @@
           tensor_meta->nbytes());
 
       inputs[input_index].resize(tensor_meta->nbytes());
-      if (input_index <= 2) {
-        fin.seekg(0, fin.beg);
-        fin.read(inputs[input_index].data(), file_size);
-      }
       fin.close();
 
       auto tensor_shape = tensor_meta->sizes();

examples/qualcomm/llama2/llama.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from functools import partial
 
 import torch
+from torch.ao.quantization.observer import MinMaxObserver
 
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
@@ -206,6 +207,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         shared_buffer=args.shared_buffer,
         metadata=instance.get_metadata(),
         direct_io=True,
+        act_observer=MinMaxObserver
     )
 
     if args.compile_only:

examples/qualcomm/llama2/runner/runner.cpp

Lines changed: 2 additions & 1 deletion
@@ -139,7 +139,8 @@ Result<torch::executor::Tensor> Runner::run_model_step(
     Tensor& start_pos,
     std::vector<Tensor>& input_tensors) {
   token.mutable_data_ptr<int32_t>()[0] = input_token;
-  // inputs:[tokens, start_pos, atten_mask, kv_mask, k_cache, v_cache]
+  // inputs:[tokens, start_pos, kv_mask, k_cache, v_cache]
+  // input_tensors:[kv_mask, k_cache, v_cache]
   std::vector<EValue> inputs = {token, start_pos};
   inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());

examples/qualcomm/scripts/utils.py

Lines changed: 4 additions & 2 deletions
@@ -15,6 +15,7 @@
 import numpy as np
 
 import torch
+from torch.ao.quantization.observer import MovingAverageMinMaxObserver
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a4w_qnn_ptq_config,
@@ -184,6 +185,7 @@ def build_executorch_binary(
     direct_io=False, # TODO: temporal workaround for llama
     shared_buffer=False,
     metadata=None,
+    act_observer=MovingAverageMinMaxObserver
 ):
     if quant_dtype is not None:
         quantizer = QnnQuantizer()
@@ -194,10 +196,10 @@
             pass # default setting
         elif quant_dtype == QuantDtype.use_16a16w:
             quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
-            quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+            quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config(act_observer=act_observer))
         elif quant_dtype == QuantDtype.use_16a4w:
             quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
-            quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config(act_observer=act_observer))
             quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")
         else:
             raise AssertionError(f"No support for QuantDtype {quant_dtype}.")
