Commit f3d57c2

[Executorch][quant] Optimize per channel dequantize
Pull Request resolved: #5670

When using the quantized KV cache, the dequantization routine takes a significant amount of time. This diff vectorizes per-channel dequantization for the common case (contiguous int8 input, float output).

ghstack-source-id: 245703485
@exported-using-ghexport

Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/)
1 parent b36c899 commit f3d57c2
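For context: per-channel dequantization maps each int8 element back to float as out = (in - zero_point[c]) * scale[c], where c indexes the channel along the quantized axis. The sketch below is a minimal scalar reference of that computation, assuming a contiguous 2-D tensor quantized along axis 0; the function name and layout here are illustrative only and are not part of the ExecuTorch API. The diff that follows adds a NEON-vectorized version of this inner loop.

#include <cstddef>
#include <cstdint>

// Illustrative scalar reference (hypothetical helper, not ExecuTorch code):
// for every channel c along the quantized axis, each int8 element is
// dequantized as out = (in - zero_point[c]) * scale[c].
void dequantize_per_channel_reference(
    const int8_t* in,            // contiguous int8 input, shape {num_channels, inner}
    const double* scales,        // one scale per channel
    const int64_t* zero_points,  // one zero point per channel (or all zeros)
    float* out,                  // float output, same shape as input
    size_t num_channels,
    size_t inner) {              // contiguous elements per channel slice
  for (size_t c = 0; c < num_channels; ++c) {
    for (size_t i = 0; i < inner; ++i) {
      const size_t idx = c * inner + i;
      out[idx] = static_cast<float>((in[idx] - zero_points[c]) * scales[c]);
    }
  }
}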

File tree

  kernels/quantized/cpu/op_dequantize.cpp
  kernels/quantized/test/op_dequantize_test.cpp

2 files changed: +226 -8 lines

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 184 additions & 0 deletions
@@ -11,6 +11,9 @@
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -61,6 +66,171 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Applies `fn` once for each position along dimension `dim` of tensor `in`,
+ * using the signature:
+ * void fn(size_t inner_size, size_t outer_idx, size_t dim_idx)
+ * where `inner_size` is the number of contiguous trailing elements, `outer_idx`
+ * indexes the leading dimensions, and `dim_idx` is the position along `dim`.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const double* scales_data = scales.const_data_ptr<double>();
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       scales_data,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = scales_data[unpacked_dim_idx];
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -225,6 +395,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 42 additions & 8 deletions
@@ -118,13 +118,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
-TEST(OpDequantizeOutTest, DequantizePerChannel) {
-  et_pal_init();
-  TensorFactory<ScalarType::Byte> tf_byte;
+template <ScalarType DTYPE>
+void test_per_channel_dtype() {
+  TensorFactory<DTYPE> tf;
   TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;
 
-  Tensor input = tf_byte.full({3, 2}, 100);
+  Tensor input = tf.full({3, 2}, 100);
   Tensor scale = tf_double.make({2}, {0.5, 1});
   Tensor zero_point = tf_long.make({2}, {30, 60});
   int64_t quant_min = 0;
@@ -142,7 +142,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/1,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
@@ -163,15 +163,15 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);
 
   // Test with a different axis
   out = tfo.zeros({3});
-  input = tf_byte.make({3}, {100, 100, 100});
+  input = tf.make({3}, {100, 100, 100});
   scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
@@ -185,8 +185,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Test with a different axis
+  input = tf.full({3, 19}, 100);
+  out = tfo.zeros({3, 19});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
+  zero_point = tf_long.make({3}, {30, 50, 60});
+  // (100 - 30) * 0.5
+  // (100 - 50) * 0.75
+  // (100 - 60) * 1
+  expected = tfo.make(
+      {3, 19},
+      {35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,
+       35,   35,   35,   35,   35,   35,   35,   37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 40,   40,   40,   40,   40,   40,   40,   40,   40,   40,
+       40,   40,   40,   40,   40,   40,   40,   40,   40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/0,
+      quant_min,
+      quant_max,
+      DTYPE,
       optional<ScalarType>(),
       out);
+
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  et_pal_init();
+  test_per_channel_dtype<ScalarType::Byte>();
+  test_per_channel_dtype<ScalarType::Char>();
+}
