Qualcomm AI Engine Direct - Enable 4 bits BW quantization #2506

Closed · wants to merge 1 commit

22 changes: 22 additions & 0 deletions backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp
@@ -32,6 +32,28 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
quantize_param_wrapper =
std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
axis, scale_offset);
} else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
int32_t axis = quant_info["axis"].cast<int32_t>();
std::vector<Qnn_ScaleOffset_t> scale_offset =
quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
uint32_t num_elements = scale_offset.size();
std::vector<float> scales;
std::vector<int32_t> offsets;
for (const auto& scale_offset : scale_offset) {
scales.push_back(scale_offset.scale);
offsets.push_back(scale_offset.offset);
}
quantize_param_wrapper =
std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
bitwidth, axis, num_elements, scales, offsets);
} else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
float scale = quant_info["scale"].cast<float>();
int32_t offset = quant_info["offset"].cast<int32_t>();
quantize_param_wrapper =
std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(
bitwidth, scale, offset);
} else if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
float scale = quant_info["scale"].cast<float>();
int32_t offset = quant_info["offset"].cast<int32_t>();
…
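For orientation, the quant_info dictionary consumed above is assembled on the Python side (see node_visitor.py below). A minimal sketch of the payloads the two new branches expect (key names come from the casts above; the concrete values are invented for illustration):

```python
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

# Hypothetical quant_info payloads; values are illustrative only.
bw_axis_scale_offset = {
    "bitwidth": 4,     # bits actually used per element
    "axis": 3,         # per-channel axis (output channels last for QNN conv)
    "scale_offset": [  # one Qnn_ScaleOffset_t(scale, offset) per channel
        PyQnnWrapper.Qnn_ScaleOffset_t(0.05, -8),
        PyQnnWrapper.Qnn_ScaleOffset_t(0.04, -7),
    ],
}
bw_scale_offset = {
    "bitwidth": 4,
    "scale": 0.05,
    "offset": -8,      # QNN offset = -zero_point
}
```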
27 changes: 27 additions & 0 deletions backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp
@@ -27,6 +27,33 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
quantize_param_wrapper =
std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
quantization.axisScaleOffsetEncoding.axis, scale_offset);
} else if (
quantization.quantizationEncoding ==
QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
std::vector<float> scales(
quantization.bwAxisScaleOffsetEncoding.scales,
quantization.bwAxisScaleOffsetEncoding.scales +
quantization.bwAxisScaleOffsetEncoding.numElements);
std::vector<int32_t> offsets(
quantization.bwAxisScaleOffsetEncoding.offsets,
quantization.bwAxisScaleOffsetEncoding.offsets +
quantization.bwAxisScaleOffsetEncoding.numElements);

quantize_param_wrapper =
std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
quantization.bwAxisScaleOffsetEncoding.bitwidth,
quantization.bwAxisScaleOffsetEncoding.axis,
quantization.bwAxisScaleOffsetEncoding.numElements,
scales,
offsets);
} else if (
quantization.quantizationEncoding ==
QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
quantize_param_wrapper =
std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(
quantization.bwScaleOffsetEncoding.bitwidth,
quantization.bwScaleOffsetEncoding.scale,
quantization.bwScaleOffsetEncoding.offset);
} else if (
quantization.quantizationEncoding ==
QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
…
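Both the pybind adaptor above and this C++ path populate the same QNN encodings. One note, based on my reading of the QNN headers rather than anything stated in this diff: QNN dequantizes as real = scale * (stored + offset), while PyTorch uses real = scale * (q - zero_point), so the builders negate zero_point when filling in offset; the bitwidth field then records that only the low 4 bits of each stored element carry the value.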
111 changes: 111 additions & 0 deletions backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h
@@ -77,6 +77,117 @@ class UndefinedQuantizeParamsWrapper final : public QuantizeParamsWrapper {
}
};

class BwAxisScaleOffsetQuantizeParamsWrapper final
: public QuantizeParamsWrapper {
public:
explicit BwAxisScaleOffsetQuantizeParamsWrapper(
std::uint32_t bitwidth,
std::int32_t axis,
std::uint32_t num_elements,
std::vector<float> scales,
std::vector<int32_t> offsets)
: QuantizeParamsWrapper(
QNN_DEFINITION_DEFINED,
QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET),
bitwidth_(bitwidth),
axis_(axis),
num_elements_(num_elements),
scales_(scales),
offsets_(offsets) {}

BwAxisScaleOffsetQuantizeParamsWrapper(
const BwAxisScaleOffsetQuantizeParamsWrapper& rhs)
: QuantizeParamsWrapper(
rhs.GetEncodingDefinition(),
rhs.GetQuantizationEncoding()),
bitwidth_(rhs.bitwidth_),
axis_(rhs.axis_),
num_elements_(rhs.num_elements_),
scales_(rhs.scales_),
offsets_(rhs.offsets_) {}
BwAxisScaleOffsetQuantizeParamsWrapper(
BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;

~BwAxisScaleOffsetQuantizeParamsWrapper() override = default;

std::unique_ptr<QuantizeParamsWrapper> Clone() override {
return std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(*this);
}

Qnn_QuantizeParams_t CreateQuantizeParams() override {
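    // The returned struct borrows scales_.data() and offsets_.data(); this
    // wrapper must outlive any use of the returned Qnn_QuantizeParams_t.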
Qnn_QuantizeParams_t rval;
rval.encodingDefinition = GetEncodingDefinition();
rval.quantizationEncoding = GetQuantizationEncoding();
rval.bwAxisScaleOffsetEncoding.bitwidth = bitwidth_;
rval.bwAxisScaleOffsetEncoding.axis = axis_;
rval.bwAxisScaleOffsetEncoding.numElements = num_elements_;
rval.bwAxisScaleOffsetEncoding.scales = scales_.data();
rval.bwAxisScaleOffsetEncoding.offsets = offsets_.data();
return rval;
}

private:
std::uint32_t bitwidth_;
std::int32_t axis_;
std::uint32_t num_elements_;
std::vector<float> scales_;
std::vector<int32_t> offsets_;
};

class BwScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
public:
explicit BwScaleOffsetQuantizeParamsWrapper(
std::uint32_t bitwidth,
float scale,
std::int32_t offset)
: QuantizeParamsWrapper(
QNN_DEFINITION_DEFINED,
QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET),
bitwidth_(bitwidth),
scale_(scale),
offset_(offset) {}

BwScaleOffsetQuantizeParamsWrapper(
const BwScaleOffsetQuantizeParamsWrapper& rhs)
: QuantizeParamsWrapper(
rhs.GetEncodingDefinition(),
rhs.GetQuantizationEncoding()),
bitwidth_(rhs.bitwidth_),
scale_(rhs.scale_),
offset_(rhs.offset_) {}
BwScaleOffsetQuantizeParamsWrapper(BwScaleOffsetQuantizeParamsWrapper&& rhs) =
delete;
BwScaleOffsetQuantizeParamsWrapper& operator=(
const BwScaleOffsetQuantizeParamsWrapper& rhs) = delete;
BwScaleOffsetQuantizeParamsWrapper& operator=(
BwScaleOffsetQuantizeParamsWrapper&& rhs) = delete;

~BwScaleOffsetQuantizeParamsWrapper() override = default;

std::unique_ptr<QuantizeParamsWrapper> Clone() override {
return std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(*this);
}

Qnn_QuantizeParams_t CreateQuantizeParams() override {
Qnn_QuantizeParams_t rval;
rval.encodingDefinition = GetEncodingDefinition();
rval.quantizationEncoding = GetQuantizationEncoding();
rval.bwScaleOffsetEncoding.bitwidth = bitwidth_;
rval.bwScaleOffsetEncoding.scale = scale_;
rval.bwScaleOffsetEncoding.offset = offset_;
return rval;
}

private:
std::uint32_t bitwidth_;
float scale_;
std::int32_t offset_;
};

class ScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
public:
explicit ScaleOffsetQuantizeParamsWrapper(float scale, std::int32_t offset)
…
165 changes: 101 additions & 64 deletions backends/qualcomm/builders/node_visitor.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, Tuple

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -38,16 +39,16 @@
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}

PER_CHANNEL_ENCODING_MAPPING = {
exir_ops.edge.quantized_decomposed.quantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
PER_CHANNEL_ENCODING = {
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

PER_TENSOR_ENCODING_MAPPING = {
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
PER_TENSOR_ENCODING = {
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}


@@ -87,6 +88,68 @@ def _get_tensor(node, index):
tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous()
return tensor

def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
quant_config = copy.deepcopy(quant_attrs)

scales = quant_attrs["scales"]
zero_points = quant_attrs["zero_points"]
assert len(scales) == len(
zero_points
), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

scale_offset = []
for i in range(len(scales)):
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
scale_offset.append(
PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
)

user_0 = list(node.users)[0]
        # The memory layout of QNN conv weights always puts the output channel last, e.g. conv2d weights are HWIO
if (
"convolution" in user_0.target.__name__
and list(node.users)[0].args[1] == node
):
quant_config["axis"] = 3

else:
quant_config["axis"] = quant_attrs["axis"]

quant_config["scale_offset"] = scale_offset
# special case for 4 bits
if (
quant_config["dtype"] == torch.int8
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
):
quant_config["bitwidth"] = 4
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
quant_config,
)
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
quant_config,
)
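
    # Worked example (illustrative, not part of the change): int8 weights
    # quantized to the 4-bit range [-8, 7] give quant_max - quant_min == 15,
    # so the check above picks BW_AXIS_SCALE_OFFSET with bitwidth 4, and a
    # per-channel zero_point of 3 is stored as a QNN offset of -3.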

def make_qnn_per_tensor_config(self, quant_attrs: Dict):
quant_config = copy.deepcopy(quant_attrs)
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
quant_config["offset"] = -quant_attrs["zero_point"]
# special case for 4 bits
if (
quant_config["dtype"] == torch.int8
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
):
quant_config["bitwidth"] = 4
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
quant_config,
)
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
quant_config,
)

def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
if not node.meta.get("quant_attrs", None):
return (
@@ -99,66 +162,35 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
if "requantize" in node.meta
else node.meta["quant_attrs"]
)
encoding = quant_attrs["encoding"]

quant_config = {}
if encoding in PER_CHANNEL_ENCODING_MAPPING:
scales = quant_attrs["scales"]
zero_points = quant_attrs["zero_points"]
assert len(scales) == len(
zero_points
), f"Per channel encoding of node {node}, has differnt size fo scales {len(scales)} and zero_points {len(zero_points)}"

scale_offset = []
for i in range(len(scales)):
scale_offset.append(
PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
)

user_0 = list(node.users)[0]
# Memory layout of QNN conv is NHW"C", need to set axis as 3
if (
type(user_0.target) != str
and user_0.target.__name__ in ["aten.convolution.default"]
and list(node.users)[0].args[1] == node
):
quant_config["axis"] = 3
else:
quant_config["axis"] = quant_attrs["axis"]

quant_config["scale_offset"] = scale_offset
quant_config["quant_max"] = quant_attrs["quant_max"]
quant_config["quant_min"] = quant_attrs["quant_min"]
quant_config["dtype"] = quant_attrs["dtype"]
return PER_CHANNEL_ENCODING_MAPPING[encoding], quant_config

# per tensor situation
quant_config["scale"] = quant_attrs["scale"]
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
quant_config["offset"] = -quant_attrs["zero_point"]
# Distinguish what data type the node is
quant_config["quant_max"] = quant_attrs["quant_max"]
quant_config["quant_min"] = quant_attrs["quant_min"]
quant_config["dtype"] = quant_attrs["dtype"]
return PER_TENSOR_ENCODING_MAPPING[encoding], quant_config
if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
return self.make_qnn_per_channel_config(node, quant_attrs)

return self.make_qnn_per_tensor_config(quant_attrs)

def get_quant_tensor_value(
self, node: torch.fx.Node, tensor: torch.Tensor, dtype
self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
) -> torch.Tensor:
quant_attrs = node.meta["quant_attrs"]
encoding = quant_attrs["encoding"]

if encoding in PER_CHANNEL_ENCODING_MAPPING:
scales = quant_attrs["scales"]
offsets = quant_attrs["zero_points"]
return tensor.div(scales).add(offsets).round().to(quant_attrs["dtype"])
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
scale = quant_attrs["scale"]
zero_point = quant_attrs["zero_point"]
else: # per channel case
scale = quant_attrs["scales"]
zero_point = quant_attrs["zero_points"]

        # Workaround: torch.uint16 quantization is not supported, so use int32 instead
dtype = (
torch.int32
if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
else quant_attrs["dtype"]
)

# per tensor situation
scale = quant_attrs["scale"]
offset = quant_attrs["zero_point"]
if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16:
return tensor.div(scale).add(offset).round().to(torch.int32)
return tensor.div(scale).add(offset).round().to(quant_attrs["dtype"])
tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Keep only the low 4 bits so the backend reads the packed values correctly
if bitwidth == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
tensor = torch.bitwise_and(mask, tensor)
return tensor
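
    # Illustration of the masking above (not part of the change): a 4-bit
    # value of -3 is stored in int8 as 0b11111101; AND with 0x0F keeps the low
    # nibble 0b1101 (13), i.e. -3 in 4-bit two's complement, so the backend
    # can unpack nibbles directly:
    #   t = torch.tensor([-3], dtype=torch.int8)
    #   torch.bitwise_and(torch.full(t.size(), 0x0F, dtype=torch.int8), t)
    #   # tensor([13], dtype=torch.int8)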

def get_tensor_type(
self,
Expand Down Expand Up @@ -278,7 +310,12 @@ def define_value(
)
else:
if quant_configs:
tensor = self.get_quant_tensor_value(node, tensor, dtype)
tensor = self.get_quant_tensor_value(
tensor,
node.meta["quant_attrs"],
dtype,
quant_configs.get("bitwidth"),
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
tensor_name,
tensor_type,
…
1 change: 1 addition & 0 deletions backends/qualcomm/builders/op_conv2d.py
@@ -248,6 +248,7 @@

filter_node = node.args[1]
filter_tensor = get_parameter(filter_node, self.edge_program)
        # PyTorch conv weights are OIHW, but QNN expects HWIO
filter_axis_order = (2, 3, 1, 0)
filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
filter_tensor_wrapper = self.define_tensor(
…
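As a quick sanity check on the permutation (not part of the change), permute((2, 3, 1, 0)) maps OIHW to HWIO, which is also why make_qnn_per_channel_config pins the quantization axis to 3 for convolution weights:

```python
import torch

w = torch.empty(8, 4, 3, 3)                # OIHW: out=8, in=4, kH=3, kW=3
hwio = w.permute(2, 3, 1, 0).contiguous()  # HWIO
assert hwio.shape == (3, 3, 4, 8)          # output channel now sits on axis 3
```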