Commit bed6c96

[Executorch][quant] Optimize per channel dequantize
When using a quantized KV cache, the dequantization routine takes a significant amount of time. This diff vectorizes per-channel dequantization for the common case. Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/) [ghstack-poisoned]
1 parent 540e91e commit bed6c96
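
For reference, the operation being vectorized is plain per-channel affine dequantization: every element along the channel axis is shifted by that channel's zero point and multiplied by that channel's scale. Below is a minimal scalar sketch of that computation; it assumes a contiguous layout with the channel axis outermost, and the function name and parameters are illustrative only, not the kernel's actual interface.

#include <cstddef>
#include <cstdint>

// Illustrative scalar reference for per-channel dequantization (hypothetical
// helper, not part of the ExecuTorch kernel). Each channel c has its own
// scale[c] and zero_point[c]; the NEON path added in this commit vectorizes
// the inner loop 16 int8 elements at a time.
void dequantize_per_channel_reference(
    const int8_t* in,
    const double* scale,
    const int64_t* zero_point,
    float* out,
    size_t num_channels,
    size_t elems_per_channel) {
  for (size_t c = 0; c < num_channels; ++c) {
    for (size_t j = 0; j < elems_per_channel; ++j) {
      const size_t idx = c * elems_per_channel + j;
      out[idx] = (in[idx] - zero_point[c]) * static_cast<float>(scale[c]);
    }
  }
}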

1 file changed (+173, -0 lines)

kernels/quantized/cpu/op_dequantize.cpp

@@ -22,6 +22,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;

 namespace {

@@ -61,6 +63,163 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }

+/**
+ * Applies the function `fn` once for every index of dimension `dim` of the
+ * tensor `in`, for every combination of the leading ("outer") dimensions.
+ * `fn` should have the following signature:
+ *   void fn(size_t inner_size, size_t outer_idx, size_t dim_idx)
+ * where `inner_size` is the number of contiguous elements trailing `dim`,
+ * `outer_idx` is the flattened index over the dimensions preceding `dim`,
+ * and `dim_idx` is the current index along `dim`.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point must be %" PRId64 " >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point must be %" PRId64 " <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  // Vectorized main loop: 16 int8 elements per iteration. The remainder
+  // (numel % kVecSize) is handled by the scalar tail loop below.
+  for (; i < num_vecs * kVecSize; i += kVecSize) {
+    int8x16_t in_vec = vld1q_s8(in + i);
+    // Widen to int16/int32 while subtracting the zero point, then convert to
+    // float and multiply by the scale.
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+
+    // Store the 16 dequantized floats.
+    vst1q_f32(out + i, out_vec_0_3);
+    vst1q_f32(out + i + 4, out_vec_4_7);
+    vst1q_f32(out + i + 8, out_vec_8_11);
+    vst1q_f32(out + i + 12, out_vec_12_15);
+  }
+#endif
+  // Scalar path (and tail of the vectorized path).
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  if (!executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()) ||
+      (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()),
+      "in must be in contiguous dim order");
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const double* scales_data = scales.const_data_ptr<double>();
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       scales_data,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = scales_data[unpacked_dim_idx];
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace

 /**
@@ -217,6 +376,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);

+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
