
Commit 2c0be5e

Qualcomm AI Engine Direct - Enable 4 bits BW quantization
- Add QNN_QUANTIZATION_ENCODING_BW... configs for qnn wrapper
- Add 4 bits quant config
- Add 4 bits quant single op tests
- Add per channel weight setting for quantizer
- Fix convert_to_linear error
- Refine quantizer
1 parent 588c391 commit 2c0be5e

20 files changed (+591, −291 lines)
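A note on how the 4-bit path is triggered throughout this commit: there is no dedicated 4-bit dtype. A config whose storage dtype is torch.int8 and whose quant_max − quant_min spans at most 16 levels is promoted to a bitwidth (BW) encoding. A minimal sketch of that check, with hypothetical attribute values:

```python
import torch

# Hypothetical quantization attributes; the real dicts in node_visitor.py
# also carry scale/zero_point (or scales/zero_points) and an encoding key.
quant_attrs = {"dtype": torch.int8, "quant_min": -8, "quant_max": 7}

# Mirrors the check added in node_visitor.py: int8 storage whose value range
# fits in 16 levels is tagged bitwidth = 4 and routed to a BW_* encoding.
is_4bit = (
    quant_attrs["dtype"] == torch.int8
    and quant_attrs["quant_max"] - quant_attrs["quant_min"] <= 15
)
print(is_4bit)  # True
```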

backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp

Lines changed: 22 additions & 0 deletions
@@ -32,6 +32,28 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
     quantize_param_wrapper =
         std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
             axis, scale_offset);
+  } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
+    uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
+    int32_t axis = quant_info["axis"].cast<int32_t>();
+    std::vector<Qnn_ScaleOffset_t> scale_offset =
+        quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
+    uint32_t num_elements = scale_offset.size();
+    std::vector<float> scales;
+    std::vector<int32_t> offsets;
+    for (const auto& scale_offset : scale_offset) {
+      scales.push_back(scale_offset.scale);
+      offsets.push_back(scale_offset.offset);
+    }
+    quantize_param_wrapper =
+        std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
+            bitwidth, axis, num_elements, scales, offsets);
+  } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
+    uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
+    float scale = quant_info["scale"].cast<float>();
+    int32_t offset = quant_info["offset"].cast<int32_t>();
+    quantize_param_wrapper =
+        std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(
+            bitwidth, scale, offset);
   } else if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
     float scale = quant_info["scale"].cast<float>();
     int32_t offset = quant_info["offset"].cast<int32_t>();
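For reference, the quant_info dict consumed by the new BW_AXIS_SCALE_OFFSET branch would look roughly like the sketch below. Key names are taken from the casts above; the values are illustrative, and the Qnn_ScaleOffset_t construction follows the builder code in node_visitor.py further down:

```python
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

# Illustrative quant_info for the BW_AXIS_SCALE_OFFSET branch: one
# (scale, offset) pair per channel, with offset = -zero_point.
quant_info = {
    "bitwidth": 4,
    "axis": 3,
    "scale_offset": [
        PyQnnWrapper.Qnn_ScaleOffset_t(0.02, -8),
        PyQnnWrapper.Qnn_ScaleOffset_t(0.05, -8),
    ],
}
```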

backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp

Lines changed: 27 additions & 0 deletions
@@ -27,6 +27,33 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
     quantize_param_wrapper =
         std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
             quantization.axisScaleOffsetEncoding.axis, scale_offset);
+  } else if (
+      quantization.quantizationEncoding ==
+      QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
+    std::vector<float> scales(
+        quantization.bwAxisScaleOffsetEncoding.scales,
+        quantization.bwAxisScaleOffsetEncoding.scales +
+            quantization.bwAxisScaleOffsetEncoding.numElements);
+    std::vector<int32_t> offsets(
+        quantization.bwAxisScaleOffsetEncoding.offsets,
+        quantization.bwAxisScaleOffsetEncoding.offsets +
+            quantization.bwAxisScaleOffsetEncoding.numElements);
+
+    quantize_param_wrapper =
+        std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
+            quantization.bwAxisScaleOffsetEncoding.bitwidth,
+            quantization.bwAxisScaleOffsetEncoding.axis,
+            quantization.bwAxisScaleOffsetEncoding.numElements,
+            scales,
+            offsets);
+  } else if (
+      quantization.quantizationEncoding ==
+      QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
+    quantize_param_wrapper =
+        std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(
+            quantization.bwScaleOffsetEncoding.bitwidth,
+            quantization.bwScaleOffsetEncoding.scale,
+            quantization.bwScaleOffsetEncoding.offset);
   } else if (
       quantization.quantizationEncoding ==
       QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {

backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h

Lines changed: 111 additions & 0 deletions
@@ -77,6 +77,117 @@ class UndefinedQuantizeParamsWrapper final : public QuantizeParamsWrapper {
   }
 };
 
+class BwAxisScaleOffsetQuantizeParamsWrapper final
+    : public QuantizeParamsWrapper {
+ public:
+  explicit BwAxisScaleOffsetQuantizeParamsWrapper(
+      std::uint32_t bitwidth,
+      std::int32_t axis,
+      std::uint32_t num_elements,
+      std::vector<float> scales,
+      std::vector<int32_t> offsets)
+      : QuantizeParamsWrapper(
+            QNN_DEFINITION_DEFINED,
+            QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET),
+        bitwidth_(bitwidth),
+        axis_(axis),
+        num_elements_(num_elements),
+        scales_(scales),
+        offsets_(offsets) {}
+
+  BwAxisScaleOffsetQuantizeParamsWrapper(
+      const BwAxisScaleOffsetQuantizeParamsWrapper& rhs)
+      : QuantizeParamsWrapper(
+            rhs.GetEncodingDefinition(),
+            rhs.GetQuantizationEncoding()),
+        bitwidth_(rhs.bitwidth_),
+        axis_(rhs.axis_),
+        num_elements_(rhs.num_elements_),
+        scales_(rhs.scales_),
+        offsets_(rhs.offsets_) {}
+  BwAxisScaleOffsetQuantizeParamsWrapper(
+      BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
+  BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
+      const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
+  BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
+      BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
+
+  ~BwAxisScaleOffsetQuantizeParamsWrapper() override = default;
+
+  std::unique_ptr<QuantizeParamsWrapper> Clone() override {
+    return std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(*this);
+  }
+
+  Qnn_QuantizeParams_t CreateQuantizeParams() override {
+    Qnn_QuantizeParams_t rval;
+    rval.encodingDefinition = GetEncodingDefinition();
+    rval.quantizationEncoding = GetQuantizationEncoding();
+    rval.bwAxisScaleOffsetEncoding.bitwidth = bitwidth_;
+    rval.bwAxisScaleOffsetEncoding.axis = axis_;
+    rval.bwAxisScaleOffsetEncoding.numElements = num_elements_;
+    rval.bwAxisScaleOffsetEncoding.scales = scales_.data();
+    rval.bwAxisScaleOffsetEncoding.offsets = offsets_.data();
+    return rval;
+  }
+
+ private:
+  std::uint32_t bitwidth_;
+  std::int32_t axis_;
+  std::uint32_t num_elements_;
+  std::vector<float> scales_;
+  std::vector<int32_t> offsets_;
+};
+
+class BwScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
+ public:
+  explicit BwScaleOffsetQuantizeParamsWrapper(
+      std::uint32_t bitwidth,
+      float scale,
+      std::int32_t offset)
+      : QuantizeParamsWrapper(
+            QNN_DEFINITION_DEFINED,
+            QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET),
+        bitwidth_(bitwidth),
+        scale_(scale),
+        offset_(offset) {}
+
+  BwScaleOffsetQuantizeParamsWrapper(
+      const BwScaleOffsetQuantizeParamsWrapper& rhs)
+      : QuantizeParamsWrapper(
+            rhs.GetEncodingDefinition(),
+            rhs.GetQuantizationEncoding()),
+        bitwidth_(rhs.bitwidth_),
+        scale_(rhs.scale_),
+        offset_(rhs.offset_) {}
+  BwScaleOffsetQuantizeParamsWrapper(BwScaleOffsetQuantizeParamsWrapper&& rhs) =
+      delete;
+  BwScaleOffsetQuantizeParamsWrapper& operator=(
+      const BwScaleOffsetQuantizeParamsWrapper& rhs) = delete;
+  BwScaleOffsetQuantizeParamsWrapper& operator=(
+      BwScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
+
+  ~BwScaleOffsetQuantizeParamsWrapper() override = default;
+
+  std::unique_ptr<QuantizeParamsWrapper> Clone() override {
+    return std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(*this);
+  }
+
+  Qnn_QuantizeParams_t CreateQuantizeParams() override {
+    Qnn_QuantizeParams_t rval;
+    rval.encodingDefinition = GetEncodingDefinition();
+    rval.quantizationEncoding = GetQuantizationEncoding();
+    rval.bwScaleOffsetEncoding.bitwidth = bitwidth_;
+    rval.bwScaleOffsetEncoding.scale = scale_;
+    rval.bwScaleOffsetEncoding.offset = offset_;
+    return rval;
+  }
+
+ private:
+  std::uint32_t bitwidth_;
+  float scale_;
+  std::int32_t offset_;
+};
+
 class ScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
  public:
   explicit ScaleOffsetQuantizeParamsWrapper(float scale, std::int32_t offset)
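Both new wrappers carry QNN's (scale, offset) convention, where offset is the negated PyTorch zero-point (see the Qnn_ScaleOffset_t comments in node_visitor.py below). A quick round-trip sketch of that convention, assuming the usual affine mapping:

```python
import torch

scale, zero_point = 0.02, 8  # illustrative PyTorch affine parameters
offset = -zero_point         # QNN stores the negated zero-point

real = torch.tensor([-0.16, 0.0, 0.14])
quant = real.div(scale).add(zero_point).round()  # quantize, as node_visitor.py does
dequant = scale * (quant + offset)               # QNN view: real = scale * (q + offset)
print(torch.allclose(real, dequant))             # True for values on the grid
```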

backends/qualcomm/builders/node_visitor.py

Lines changed: 101 additions & 64 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 from typing import Any, Dict, Tuple
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -38,16 +39,16 @@
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }
 
-PER_CHANNEL_ENCODING_MAPPING = {
-    exir_ops.edge.quantized_decomposed.quantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
-    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
+PER_CHANNEL_ENCODING = {
+    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
 }
 
-PER_TENSOR_ENCODING_MAPPING = {
-    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
-    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
-    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
-    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
+PER_TENSOR_ENCODING = {
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
 }
 
 
@@ -87,6 +88,68 @@ def _get_tensor(node, index):
             tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous()
         return tensor
 
+    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
+        quant_config = copy.deepcopy(quant_attrs)
+
+        scales = quant_attrs["scales"]
+        zero_points = quant_attrs["zero_points"]
+        assert len(scales) == len(
+            zero_points
+        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"
+
+        scale_offset = []
+        for i in range(len(scales)):
+            # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
+            scale_offset.append(
+                PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
+            )
+
+        user_0 = list(node.users)[0]
+        # QNN conv weight layouts always end with the output channel, e.g. conv2d weight is HWIO
+        if (
+            "convolution" in user_0.target.__name__
+            and list(node.users)[0].args[1] == node
+        ):
+            quant_config["axis"] = 3
+
+        else:
+            quant_config["axis"] = quant_attrs["axis"]
+
+        quant_config["scale_offset"] = scale_offset
+        # special case for 4 bits
+        if (
+            quant_config["dtype"] == torch.int8
+            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+        ):
+            quant_config["bitwidth"] = 4
+            return (
+                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
+                quant_config,
+            )
+        return (
+            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
+            quant_config,
+        )
+
+    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
+        quant_config = copy.deepcopy(quant_attrs)
+        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
+        quant_config["offset"] = -quant_attrs["zero_point"]
+        # special case for 4 bits
+        if (
+            quant_config["dtype"] == torch.int8
+            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+        ):
+            quant_config["bitwidth"] = 4
+            return (
+                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
+                quant_config,
+            )
+        return (
+            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
+            quant_config,
+        )
+
     def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
         if not node.meta.get("quant_attrs", None):
             return (
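One subtlety in make_qnn_per_channel_config above: when the quantized parameter feeds a convolution as its weight (args[1]), the per-channel axis is forced to 3, because op_conv2d.py permutes the weight into a QNN layout whose last dimension is the output channel. A short sketch of why PyTorch's axis-0 channels end up on axis 3 (illustrative shapes):

```python
import torch

# PyTorch conv2d weight is OIHW, with per-channel params along axis 0.
w_oihw = torch.randn(16, 8, 3, 3)

# op_conv2d.py permutes to HWIO for QNN, moving output channels to axis 3.
w_hwio = w_oihw.permute(2, 3, 1, 0).contiguous()
assert w_hwio.shape[3] == w_oihw.shape[0]  # the 16 output channels moved to axis 3
```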
@@ -99,66 +162,35 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
             if "requantize" in node.meta
             else node.meta["quant_attrs"]
         )
-        encoding = quant_attrs["encoding"]
-
-        quant_config = {}
-        if encoding in PER_CHANNEL_ENCODING_MAPPING:
-            scales = quant_attrs["scales"]
-            zero_points = quant_attrs["zero_points"]
-            assert len(scales) == len(
-                zero_points
-            ), f"Per channel encoding of node {node}, has differnt size fo scales {len(scales)} and zero_points {len(zero_points)}"
-
-            scale_offset = []
-            for i in range(len(scales)):
-                scale_offset.append(
-                    PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
-                )
 
-            user_0 = list(node.users)[0]
-            # Memory layout of QNN conv is NHW"C", need to set axis as 3
-            if (
-                type(user_0.target) != str
-                and user_0.target.__name__ in ["aten.convolution.default"]
-                and list(node.users)[0].args[1] == node
-            ):
-                quant_config["axis"] = 3
-            else:
-                quant_config["axis"] = quant_attrs["axis"]
-
-            quant_config["scale_offset"] = scale_offset
-            quant_config["quant_max"] = quant_attrs["quant_max"]
-            quant_config["quant_min"] = quant_attrs["quant_min"]
-            quant_config["dtype"] = quant_attrs["dtype"]
-            return PER_CHANNEL_ENCODING_MAPPING[encoding], quant_config
-
-        # per tensor situation
-        quant_config["scale"] = quant_attrs["scale"]
-        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
-        quant_config["offset"] = -quant_attrs["zero_point"]
-        # Distinguish what data type the node is
-        quant_config["quant_max"] = quant_attrs["quant_max"]
-        quant_config["quant_min"] = quant_attrs["quant_min"]
-        quant_config["dtype"] = quant_attrs["dtype"]
-        return PER_TENSOR_ENCODING_MAPPING[encoding], quant_config
+        if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
+            return self.make_qnn_per_channel_config(node, quant_attrs)
+
+        return self.make_qnn_per_tensor_config(quant_attrs)
 
     def get_quant_tensor_value(
-        self, node: torch.fx.Node, tensor: torch.Tensor, dtype
+        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
     ) -> torch.Tensor:
-        quant_attrs = node.meta["quant_attrs"]
-        encoding = quant_attrs["encoding"]
-
-        if encoding in PER_CHANNEL_ENCODING_MAPPING:
-            scales = quant_attrs["scales"]
-            offsets = quant_attrs["zero_points"]
-            return tensor.div(scales).add(offsets).round().to(quant_attrs["dtype"])
+        if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
+            scale = quant_attrs["scale"]
+            zero_point = quant_attrs["zero_point"]
+        else:  # per channel case
+            scale = quant_attrs["scales"]
+            zero_point = quant_attrs["zero_points"]
+
+        # torch.uint16 quantization is not supported, so cast to int32 instead
+        dtype = (
+            torch.int32
+            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
+            else quant_attrs["dtype"]
+        )
 
-        # per tensor situation
-        scale = quant_attrs["scale"]
-        offset = quant_attrs["zero_point"]
-        if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16:
-            return tensor.div(scale).add(offset).round().to(torch.int32)
-        return tensor.div(scale).add(offset).round().to(quant_attrs["dtype"])
+        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
+        # Mask to the low nibble so the backend reads the 4-bit data correctly
+        if bitwidth == 4:
+            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
+            tensor = torch.bitwise_and(mask, tensor)
+        return tensor
 
     def get_tensor_type(
         self,
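For the 4-bit case, get_quant_tensor_value masks each int8 value down to its low nibble so the backend reads two's-complement 4-bit data correctly. A quick illustration of what the mask does to a negative value:

```python
import torch

q = torch.tensor([-3, 7], dtype=torch.int8)  # values already in the 4-bit range
mask = torch.full(q.size(), 0x0F, dtype=torch.int8)
packed = torch.bitwise_and(mask, q)
print(packed)  # tensor([13, 7], dtype=torch.int8): -3 -> 0x0D, the low nibble of 0xFD
```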
@@ -278,7 +310,12 @@ def define_value(
             )
         else:
             if quant_configs:
-                tensor = self.get_quant_tensor_value(node, tensor, dtype)
+                tensor = self.get_quant_tensor_value(
+                    tensor,
+                    node.meta["quant_attrs"],
+                    dtype,
+                    quant_configs.get("bitwidth"),
+                )
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
                 tensor_type,
backends/qualcomm/builders/op_conv2d.py

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ def define_node(
 
         filter_node = node.args[1]
         filter_tensor = get_parameter(filter_node, self.edge_program)
+        # PyTorch conv weight is OIHW, but QNN expects HWIO
         filter_axis_order = (2, 3, 1, 0)
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
