
Commit ddec0c7

[Executorch][quant] Optimize per channel dequantize
Pull Request resolved: #5670

When using a quantized KV cache, the dequantization routine takes a significant amount of time. This diff vectorizes per-channel dequantization for the common case.

ghstack-source-id: 255730818
@exported-using-ghexport

Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/)

Co-authored-by: Kimish Patel <[email protected]>
1 parent 9d084c4 commit ddec0c7
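For context, per-channel dequantization applies the affine map out = (q - zero_point[c]) * scale[c], where c indexes the channel along the chosen axis. Below is a minimal scalar sketch of that reference behavior, the loop that the new NEON path in this commit vectorizes. The function name and the [channels, inner] layout are illustrative only, not the ExecuTorch kernel signature.

#include <cstddef>
#include <cstdint>

// Hypothetical standalone sketch: dequantize a contiguous [channels, inner]
// int8 buffer, with one (scale, zero_point) pair per channel along axis 0.
// out[c][i] = (in[c][i] - zero_point[c]) * scale[c]
void dequantize_per_channel_ref(
    const int8_t* in,
    const double* scales,
    const int64_t* zero_points,
    float* out,
    size_t channels,
    size_t inner) {
  for (size_t c = 0; c < channels; ++c) {
    const float scale = static_cast<float>(scales[c]);
    const int64_t zp = zero_points[c];
    for (size_t i = 0; i < inner; ++i) {
      // Scalar inner loop; the commit replaces this with 16-lane NEON work.
      out[c * inner + i] = (in[c * inner + i] - zp) * scale;
    }
  }
}

Per the commit message, this inner loop dominates when the KV cache is quantized, which is what the vectorized path below addresses.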

File tree: 2 files changed, +238 −21 lines (kernels/quantized/cpu/op_dequantize.cpp, kernels/quantized/test/op_dequantize_test.cpp)

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 196 additions & 13 deletions
@@ -11,6 +11,9 @@
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif

 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;

 namespace {

@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }

+/**
+ * Useful to reduce a tensor `in` over a given dimension `dim` using the
+ * reduce function `fn`, which should have the following signature:
+ * void fn(const size_t size, const size_t stride, const size_t base_ix)
+ * where `size` and `stride` are the size and stride of the dimension being
+ * reduced and `base_ix` is the index of the first element of the reduction.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point must be %" PRId64 " <= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point must be %" PRId64 " >= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace

 /**
@@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }

-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);

+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
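On aarch64/NEON builds, dequantize_optimized processes 16 int8 values per iteration: vld1q_s8 loads a 16-lane vector, vsubl_s8 subtracts the zero point while widening to two int16x8 halves, vmovl_s16 widens those to four int32x4 vectors, and vcvtq_f32_s32 followed by vmulq_f32 converts to float and applies the scale before four vst1q_f32 stores. The statement i = i * kVecSize then converts the vector count back into an element index so the remaining numel % 16 elements fall through to the scalar loop.

The dispatch in dequantize_per_channel_optimized relies on contiguity: axis_stride is the stride of the quantized axis (for a contiguous tensor, the product of the trailing dimensions) and outer_stride = size(axis) * axis_stride, so each (outer_idx, channel) pair owns a contiguous run of axis_stride elements that share one scale and zero point. The following small standalone sketch is hypothetical, not part of the kernel; it just prints those runs for a contiguous {2, 3, 4} tensor quantized along axis 1.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical illustration of the indexing used by
  // dequantize_per_channel_optimized for a contiguous {2, 3, 4} tensor,
  // quantized along axis 1 (3 channels).
  const int64_t sizes[] = {2, 3, 4};
  const int64_t axis = 1;
  const int64_t axis_stride = sizes[2]; // product of trailing dims = 4
  const int64_t outer_stride = sizes[axis] * axis_stride; // 3 * 4 = 12
  const int64_t outer_size = sizes[0]; // product of leading dims = 2

  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    for (int64_t channel = 0; channel < sizes[axis]; ++channel) {
      const int64_t base = outer_idx * outer_stride + channel * axis_stride;
      // Each run of axis_stride elements starting at `base` is dequantized
      // with scale[channel] and zero_point[channel] in one vectorized call.
      std::printf(
          "outer %lld, channel %lld -> elements [%lld, %lld)\n",
          (long long)outer_idx,
          (long long)channel,
          (long long)base,
          (long long)(base + axis_stride));
    }
  }
  return 0;
}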

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 42 additions & 8 deletions
@@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   EXPECT_TENSOR_EQ(out, expected);
 }

-TEST(OpDequantizeOutTest, DequantizePerChannel) {
-  et_pal_init();
-  TensorFactory<ScalarType::Byte> tf_byte;
+template <ScalarType DTYPE>
+void test_per_channel_dtype() {
+  TensorFactory<DTYPE> tf;
   TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;

-  Tensor input = tf_byte.full({3, 2}, 100);
+  Tensor input = tf.full({3, 2}, 100);
   Tensor scale = tf_double.make({2}, {0.5, 1});
   Tensor zero_point = tf_long.make({2}, {30, 60});
   int64_t quant_min = 0;
@@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/1,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);

@@ -168,15 +168,15 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);

   EXPECT_TENSOR_EQ(out, expected);

   // Test with a different axis
   out = tfo.zeros({3});
-  input = tf_byte.make({3}, {100, 100, 100});
+  input = tf.make({3}, {100, 100, 100});
   scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
@@ -190,8 +190,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Test with a different axis
+  input = tf.full({3, 19}, 100);
+  out = tfo.zeros({3, 19});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
+  zero_point = tf_long.make({3}, {30, 50, 60});
+  // (100 - 30) * 0.5
+  // (100 - 50) * 0.75
+  // (100 - 60) * 1
+  expected = tfo.make(
+      {3, 19},
+      {35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,
+       35,   35,   35,   35,   35,   35,   35,   37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 40,   40,   40,   40,   40,   40,   40,   40,   40,   40,
+       40,   40,   40,   40,   40,   40,   40,   40,   40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/0,
+      quant_min,
+      quant_max,
+      DTYPE,
       optional<ScalarType>(),
       out);
+
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  et_pal_init();
+  test_per_channel_dtype<ScalarType::Byte>();
+  test_per_channel_dtype<ScalarType::Char>();
+}
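The new {3, 19} case appears sized so that each 19-element per-channel run exercises both halves of dequantize_optimized on ARM: one full 16-lane NEON iteration plus a 3-element scalar tail, since 19 = 16 + 3. The expected values follow directly from (q - zero_point) * scale: (100 - 30) * 0.5 = 35, (100 - 50) * 0.75 = 37.5, and (100 - 60) * 1 = 40. Templating the test over DTYPE also adds coverage for ScalarType::Char, the int8 dtype accepted by the optimized path, alongside the original ScalarType::Byte coverage.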
