diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 3f40dc56737..60c51b055d0 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -14,8 +14,6 @@
 
 from executorch.exir.dialects._ops import ops as exir_ops
 
-from .qnn_constants import QNN_uint16
-
 from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
 
 
@@ -26,7 +24,7 @@
     # Note that there is no int64 tensor data type in Qnn.
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
     torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
@@ -36,7 +34,7 @@
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }
 
@@ -170,7 +168,7 @@ def get_quant_encoding_conf(
         return self.make_qnn_per_tensor_config(quant_attrs)
 
     def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
             scale = quant_attrs["scale"]
@@ -179,16 +177,11 @@ def get_quant_tensor_value(
             scale = quant_attrs["scales"]
             zero_point = quant_attrs["zero_points"]
 
-        # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
+        dtype = quant_configs["dtype"]
 
         tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
             tensor = torch.bitwise_and(mask, tensor)
         return tensor
@@ -237,7 +230,7 @@ def get_data_type(
                 <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
             ):
                 if unsigned:
-                    quant_config["dtype"] = QNN_uint16
+                    quant_config["dtype"] = torch.uint16
                 else:
                     quant_config["dtype"] = torch.int16
             return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
@@ -328,8 +321,7 @@ def define_tensor(
             tensor = self.get_quant_tensor_value(
                 tensor,
                 node.meta["quant_attrs"],
-                dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
             )
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 2adb3102357..118a0768d9d 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -8,7 +8,6 @@
 from enum import IntEnum, unique
 
 QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw"
-QNN_uint16 = "uint16"
 
 # Below constants should be same as those in QNN headers.
 # Maybe someday we should expose these constants by pybind
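
Not part of the diff itself, just an illustration: a minimal standalone sketch of the quantization math that `get_quant_tensor_value` performs after this change, with the target dtype now read straight from `quant_configs` instead of being special-cased when QNN reports `QNN_DATATYPE_UINT_16`. The helper name `quantize_per_tensor` and the sample scale/zero-point values are invented for the example; `torch.uint16` itself is only available on PyTorch builds that expose the unsigned 16-bit dtype (2.3+), and operator support for it is limited.

```python
import torch


def quantize_per_tensor(
    tensor: torch.Tensor, quant_attrs: dict, quant_configs: dict
) -> torch.Tensor:
    """Sketch of the per-tensor branch of get_quant_tensor_value after this PR."""
    scale = quant_attrs["scale"]
    zero_point = quant_attrs["zero_point"]
    # The storage dtype (e.g. torch.uint16) comes directly from quant_configs now.
    dtype = quant_configs["dtype"]

    # q = round(x / scale + zero_point), then cast to the integral storage dtype.
    quantized = tensor.div(scale).add(zero_point).round().to(dtype)

    # 4-bit weights keep only the low nibble so the backend reads the data correctly.
    if quant_configs.get("bitwidth") == 4:
        mask = torch.full(quantized.size(), 0x0F, dtype=torch.int8)
        quantized = torch.bitwise_and(mask, quantized)
    return quantized


if __name__ == "__main__":
    x = torch.tensor([0.0, 0.5, 1.0])
    attrs = {"scale": 1.0 / 65535, "zero_point": 0}
    configs = {"dtype": torch.uint16}  # requires PyTorch 2.3+ for torch.uint16
    print(quantize_per_tensor(x, attrs, configs).tolist())  # roughly [0, 32768, 65535]
```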