@@ -22,6 +22,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
+ using StridesType = exec_aten::StridesType;
+ using SizesType = exec_aten::SizesType;

namespace {

@@ -61,6 +63,163 @@ void check_dequantize_per_tensor_args(
      quant_max);
}

+ /**
+  * Applies the function `fn` over a given dimension `dim` of the tensor `in`.
+  * `fn` should have the following signature:
+  *   void fn(size_t inner_size, size_t outer_idx, size_t dim_idx)
+  * where `inner_size` is the number of contiguous elements following `dim`,
+  * `outer_idx` indexes the dimensions preceding `dim`, and `dim_idx` is the
+  * index along `dim`. `fn` is invoked once per (outer_idx, dim_idx) pair.
+  */
+ template <typename Fn>
+ void apply_over_unpacked_dim(
+     const Fn& fn,
+     const exec_aten::Tensor& in,
+     const int64_t& dim) {
+   if (in.numel() == 0) {
+     return;
+   }
+
+   ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+   ET_CHECK_VALID_DIM(dim, in.dim());
+
+   const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+   const size_t dim_size = in.size(d);
+   const size_t outer_size = getLeadingDims(in, d);
+   const size_t inner_size = getTrailingDims(in, d);
+   // Loop through all outer dimensions
+   for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+     // Loop through dim
+     for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+          ++unpacked_dim_idx) {
+       fn(inner_size, outer_idx, unpacked_dim_idx);
+     }
+   }
+ }
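+ // Illustrative usage sketch (hypothetical tensor `t`): visit every index
+ // along dim 1, receiving the trailing contiguous element count per slice:
+ //   apply_over_unpacked_dim(
+ //       [](size_t inner_size, size_t outer_idx, size_t dim_idx) {
+ //         // process the slice at (outer_idx, dim_idx)
+ //       },
+ //       t,
+ //       /*dim=*/1);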
+
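+ // Dequantizes `numel` contiguous int8 values that share one scale/zero_point:
+ //   out[i] = (in[i] - zero_point) * scale
+ // A NEON path (aarch64 / ARM_NEON) handles 16 elements per iteration; the
+ // scalar loop at the end covers other builds and any remainder.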
+ void dequantize_optimized(
+     const int8_t* in,
+     const double scale,
+     const int64_t zero_point,
+     float* out,
+     int64_t quant_min,
+     int64_t quant_max,
+     size_t numel) {
+   ET_CHECK_MSG(
+       zero_point >= quant_min,
+       "zero_point must be %" PRId64 " >= quant_min %" PRId64,
+       zero_point,
+       quant_min);
+   ET_CHECK_MSG(
+       zero_point <= quant_max,
+       "zero_point must be %" PRId64 " <= quant_max %" PRId64,
+       zero_point,
+       quant_max);
+   size_t i = 0;
+ #if defined(__aarch64__) || defined(__ARM_NEON)
+   int8x8_t zero_point_vec = vdup_n_s8(static_cast<int8_t>(zero_point));
+   float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+   constexpr int32_t kVecSize = 16;
+   const size_t num_vecs = numel / kVecSize;
+   // Process full 16-element vectors; the scalar loop below handles the
+   // remaining numel % kVecSize elements.
+   for (size_t vec_idx = 0; vec_idx < num_vecs; ++vec_idx, i += kVecSize) {
+     int8x16_t in_vec = vld1q_s8(in + i);
+     int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+     int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+     int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+     float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+     float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+     int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+     int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+     int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+     float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+     float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+
+     // Store the 16 dequantized floats back to the output buffer.
+     vst1q_f32(out + i, out_vec_0_3);
+     vst1q_f32(out + i + 4, out_vec_4_7);
+     vst1q_f32(out + i + 8, out_vec_8_11);
+     vst1q_f32(out + i + 12, out_vec_12_15);
+   }
+ #endif
+   for (; i < numel; i++) {
+     out[i] = (in[i] - zero_point) * scale;
+   }
+ }
+
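+ // Returns true when the fast per-channel path applies: contiguous dim order,
+ // int8 (Char) input, and float output (or no explicit output dtype).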
+ bool can_use_optimized_dequantize_per_channel(
+     const Tensor& in,
+     const ScalarType in_dtype,
+     exec_aten::optional<ScalarType>& out_dtype) {
+   if (!executorch::runtime::is_contiguous_dim_order(
+           in.dim_order().data(), in.dim()) ||
+       (in_dtype != ScalarType::Char) ||
+       (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+     return false;
+   }
+   return true;
+ }
+
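+ // Dequantizes an int8 tensor channel-wise along `axis`: each index along
+ // `axis` has its own scale (and optional zero_point), applied to every
+ // element of that slice via dequantize_optimized().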
+ void dequantize_per_channel_optimized(
+     const Tensor& in,
+     const Tensor& scales,
+     const optional<Tensor>& opt_zero_points,
+     Tensor& out,
+     int64_t axis,
+     int64_t quant_min,
+     int64_t quant_max,
+     ScalarType in_dtype,
+     exec_aten::optional<ScalarType>& out_dtype) {
+   check_dequantize_per_tensor_args(
+       in, quant_min, quant_max, in_dtype, out_dtype, out);
+   ET_CHECK_MSG(
+       executorch::runtime::is_contiguous_dim_order(
+           in.dim_order().data(), in.dim()),
+       "in must be in contiguous dim order");
+   ET_CHECK_MSG(
+       in_dtype == ScalarType::Char,
+       "in.scalar_type() %" PRId8 " is not supported:",
+       static_cast<int8_t>(in.scalar_type()));
+   if (out_dtype.has_value()) {
+     ET_CHECK_MSG(
+         out_dtype.value() == ScalarType::Float,
+         "Only float output is supported");
+   }
+   const int8_t* in_data = in.const_data_ptr<int8_t>();
+   float* out_data = out.mutable_data_ptr<float>();
+   const int64_t* zero_points_data = nullptr;
+   if (opt_zero_points.has_value()) {
+     zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+   }
+   const double* scales_data = scales.const_data_ptr<double>();
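+   // With contiguous dim order, elements along `axis` are axis_stride apart,
+   // and stepping the flattened outer index by one moves outer_stride
+   // (= size(axis) * axis_stride) elements. Since `out` has the same shape,
+   // the same offsets are reused for both buffers.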
+   const StridesType axis_stride = in.strides()[axis];
+   const StridesType outer_stride = in.size(axis) * axis_stride;
+   apply_over_unpacked_dim(
+       [in_data,
+        out_data,
+        scales_data,
+        zero_points_data,
+        axis_stride,
+        outer_stride,
+        quant_min,
+        quant_max](
+           SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+         const int8_t* in_data_local =
+             in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+         const double scale = scales_data[unpacked_dim_idx];
+         const int64_t zero_point = zero_points_data != nullptr
+             ? zero_points_data[unpacked_dim_idx]
+             : 0;
+         float* out_data_local = out_data + outer_idx * outer_stride +
+             unpacked_dim_idx * axis_stride;
+         dequantize_optimized(
+             in_data_local,
+             scale,
+             zero_point,
+             out_data_local,
+             quant_min,
+             quant_max,
+             numel);
+       },
+       in,
+       axis);
+ }
+
} // namespace

/**
@@ -217,6 +376,20 @@ Tensor& dequantize_per_channel_out(
  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

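+   // Fast path: for contiguous int8 input with float output, dequantize each
+   // channel slice directly instead of using the generic fallback below.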
+   if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+     dequantize_per_channel_optimized(
+         input,
+         scale,
+         opt_zero_points,
+         out,
+         axis,
+         quant_min,
+         quant_max,
+         dtype,
+         out_dtype);
+     return out;
+   }
+
  // a list contains all dimensions except axis
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {