4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import copy
7
8
from typing import Any , Dict , Tuple
8
9
9
10
import executorch .backends .qualcomm .python .PyQnnWrapperAdaptor as PyQnnWrapper
38
39
float : PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_FLOAT_32 ,
39
40
}
40
41
41
# Edge ops whose quant attrs carry per-channel parameters ("scales"/"zero_points").
PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

# Edge ops whose quant attrs carry per-tensor parameters ("scale"/"zero_point").
PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}
52
53
53
54
@@ -87,6 +88,68 @@ def _get_tensor(node, index):
87
88
tensor = tensor .permute (dims = op_node .meta ["axis_order" ]).contiguous ()
88
89
return tensor
89
90
91
+ def make_qnn_per_channel_config (self , node : torch .fx .Node , quant_attrs : Dict ):
92
+ quant_config = copy .deepcopy (quant_attrs )
93
+
94
+ scales = quant_attrs ["scales" ]
95
+ zero_points = quant_attrs ["zero_points" ]
96
+ assert len (scales ) == len (
97
+ zero_points
98
+ ), f"Per channel encoding of node { node } , has different size for scales { len (scales )} and zero_points { len (zero_points )} "
99
+
100
+ scale_offset = []
101
+ for i in range (len (scales )):
102
+ # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
103
+ scale_offset .append (
104
+ PyQnnWrapper .Qnn_ScaleOffset_t (scales [i ], - zero_points [i ])
105
+ )
106
+
107
+ user_0 = list (node .users )[0 ]
108
+ # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
109
+ if (
110
+ "convolution" in user_0 .target .__name__
111
+ and list (node .users )[0 ].args [1 ] == node
112
+ ):
113
+ quant_config ["axis" ] = 3
114
+
115
+ else :
116
+ quant_config ["axis" ] = quant_attrs ["axis" ]
117
+
118
+ quant_config ["scale_offset" ] = scale_offset
119
+ # special case for 4 bits
120
+ if (
121
+ quant_config ["dtype" ] == torch .int8
122
+ and quant_config ["quant_max" ] - quant_config ["quant_min" ] <= 15
123
+ ):
124
+ quant_config ["bitwidth" ] = 4
125
+ return (
126
+ PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET ,
127
+ quant_config ,
128
+ )
129
+ return (
130
+ PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET ,
131
+ quant_config ,
132
+ )
133
+
134
+ def make_qnn_per_tensor_config (self , quant_attrs : Dict ):
135
+ quant_config = copy .deepcopy (quant_attrs )
136
+ # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
137
+ quant_config ["offset" ] = - quant_attrs ["zero_point" ]
138
+ # special case for 4 bits
139
+ if (
140
+ quant_config ["dtype" ] == torch .int8
141
+ and quant_config ["quant_max" ] - quant_config ["quant_min" ] <= 15
142
+ ):
143
+ quant_config ["bitwidth" ] = 4
144
+ return (
145
+ PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET ,
146
+ quant_config ,
147
+ )
148
+ return (
149
+ PyQnnWrapper .Qnn_QuantizationEncoding_t .QNN_QUANTIZATION_ENCODING_SCALE_OFFSET ,
150
+ quant_config ,
151
+ )
152
+
90
153
def get_quant_encoding_conf (self , node : torch .fx .Node ) -> Tuple [Any , Dict ]:
91
154
if not node .meta .get ("quant_attrs" , None ):
92
155
return (
@@ -99,66 +162,35 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
99
162
if "requantize" in node .meta
100
163
else node .meta ["quant_attrs" ]
101
164
)
102
- encoding = quant_attrs ["encoding" ]
103
-
104
- quant_config = {}
105
- if encoding in PER_CHANNEL_ENCODING_MAPPING :
106
- scales = quant_attrs ["scales" ]
107
- zero_points = quant_attrs ["zero_points" ]
108
- assert len (scales ) == len (
109
- zero_points
110
- ), f"Per channel encoding of node { node } , has differnt size fo scales { len (scales )} and zero_points { len (zero_points )} "
111
-
112
- scale_offset = []
113
- for i in range (len (scales )):
114
- scale_offset .append (
115
- PyQnnWrapper .Qnn_ScaleOffset_t (scales [i ], - zero_points [i ])
116
- )
117
165
118
- user_0 = list (node .users )[0 ]
119
- # Memory layout of QNN conv is NHW"C", need to set axis as 3
120
- if (
121
- type (user_0 .target ) != str
122
- and user_0 .target .__name__ in ["aten.convolution.default" ]
123
- and list (node .users )[0 ].args [1 ] == node
124
- ):
125
- quant_config ["axis" ] = 3
126
- else :
127
- quant_config ["axis" ] = quant_attrs ["axis" ]
128
-
129
- quant_config ["scale_offset" ] = scale_offset
130
- quant_config ["quant_max" ] = quant_attrs ["quant_max" ]
131
- quant_config ["quant_min" ] = quant_attrs ["quant_min" ]
132
- quant_config ["dtype" ] = quant_attrs ["dtype" ]
133
- return PER_CHANNEL_ENCODING_MAPPING [encoding ], quant_config
134
-
135
- # per tensor situation
136
- quant_config ["scale" ] = quant_attrs ["scale" ]
137
- # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
138
- quant_config ["offset" ] = - quant_attrs ["zero_point" ]
139
- # Distinguish what data type the node is
140
- quant_config ["quant_max" ] = quant_attrs ["quant_max" ]
141
- quant_config ["quant_min" ] = quant_attrs ["quant_min" ]
142
- quant_config ["dtype" ] = quant_attrs ["dtype" ]
143
- return PER_TENSOR_ENCODING_MAPPING [encoding ], quant_config
166
+ if quant_attrs ["encoding" ] in PER_CHANNEL_ENCODING :
167
+ return self .make_qnn_per_channel_config (node , quant_attrs )
168
+
169
+ return self .make_qnn_per_tensor_config (quant_attrs )
144
170
145
171
def get_quant_tensor_value (
146
- self , node : torch .fx . Node , tensor : torch . Tensor , dtype
172
+ self , tensor : torch .Tensor , quant_attrs : Dict , dtype , bitwidth
147
173
) -> torch .Tensor :
148
- quant_attrs = node .meta ["quant_attrs" ]
149
- encoding = quant_attrs ["encoding" ]
150
-
151
- if encoding in PER_CHANNEL_ENCODING_MAPPING :
152
- scales = quant_attrs ["scales" ]
153
- offsets = quant_attrs ["zero_points" ]
154
- return tensor .div (scales ).add (offsets ).round ().to (quant_attrs ["dtype" ])
174
+ if quant_attrs ["encoding" ] in PER_TENSOR_ENCODING :
175
+ scale = quant_attrs ["scale" ]
176
+ zero_point = quant_attrs ["zero_point" ]
177
+ else : # per channel case
178
+ scale = quant_attrs ["scales" ]
179
+ zero_point = quant_attrs ["zero_points" ]
180
+
181
+ # To bypass torch.uint16 quantization is not supported
182
+ dtype = (
183
+ torch .int32
184
+ if dtype == PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_16
185
+ else quant_attrs ["dtype" ]
186
+ )
155
187
156
- # per tensor situation
157
- scale = quant_attrs [ "scale" ]
158
- offset = quant_attrs [ "zero_point" ]
159
- if dtype == PyQnnWrapper . Qnn_DataType_t . QNN_DATATYPE_UINT_16 :
160
- return tensor . div ( scale ). add ( offset ). round (). to ( torch .int32 )
161
- return tensor . div ( scale ). add ( offset ). round (). to ( quant_attrs [ "dtype" ])
188
+ tensor = tensor . div ( scale ). add ( zero_point ). round (). to ( dtype )
189
+ # Make the backends access data correctly
190
+ if bitwidth == 4 :
191
+ mask = torch . full ( tensor . size (), 0x0F , dtype = torch . int8 )
192
+ tensor = torch .bitwise_and ( mask , tensor )
193
+ return tensor
162
194
163
195
def get_tensor_type (
164
196
self ,
@@ -278,7 +310,12 @@ def define_value(
278
310
)
279
311
else :
280
312
if quant_configs :
281
- tensor = self .get_quant_tensor_value (node , tensor , dtype )
313
+ tensor = self .get_quant_tensor_value (
314
+ tensor ,
315
+ node .meta ["quant_attrs" ],
316
+ dtype ,
317
+ quant_configs .get ("bitwidth" ),
318
+ )
282
319
tensor_wrapper = PyQnnWrapper .TensorWrapper (
283
320
tensor_name ,
284
321
tensor_type ,
0 commit comments