@@ -11,6 +11,9 @@
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Applies `fn` to every slice of tensor `in` along dimension `dim`. The
+ * callback is invoked as:
+ *   fn(inner_size, outer_idx, unpacked_dim_idx)
+ * where `inner_size` is the number of contiguous trailing elements,
+ * `outer_idx` indexes the leading dimensions, and `unpacked_dim_idx` indexes
+ * the dimension `dim` being iterated over.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
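+  // NEON fast path: each iteration loads 16 int8 values, widens them while
+  // subtracting zero_point, converts to float, multiplies by scale, and
+  // stores 16 floats. Leftover elements fall through to the scalar loop
+  // below.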
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
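+  // `i` counted whole 16-element vectors above; convert it to an element
+  // index so the scalar tail loop resumes at the first unprocessed element.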
+  i = i * kVecSize;
+#endif
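+  // Scalar path: handles the tail on NEON builds and all elements otherwise.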
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
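+// The optimized per-channel path below only handles contiguous int8 inputs
+// dequantized to float; everything else uses the generic implementation.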
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
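+  // For each (outer index, channel index along `axis`) pair, dequantize the
+  // contiguous block of trailing elements with that channel's scale and
+  // zero_point.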
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
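+  // Fast path: contiguous int8 input dequantized to float; otherwise fall
+  // through to the generic implementation below.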
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {