[Executorch][quant] Optimize per channel dequantize #5622

Closed
wants to merge 1 commit into from
173 changes: 173 additions & 0 deletions kernels/quantized/cpu/op_dequantize.cpp
@@ -22,6 +22,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
using StridesType = exec_aten::StridesType;
using SizesType = exec_aten::SizesType;

namespace {

@@ -61,6 +63,163 @@ void check_dequantize_per_tensor_args(
quant_max);
}

/**
* Applies `fn` to a tensor `in` once for every index of the dimension `dim`,
* for each combination of the leading ("outer") dimensions. `fn` should have
* the following signature:
* void fn(const size_t inner_size, const size_t outer_idx, const size_t dim_idx)
* where `inner_size` is the number of contiguous elements trailing `dim`,
* `outer_idx` indexes the flattened leading dimensions, and `dim_idx` is the
* current index along `dim`.
*/
template <typename Fn>
void apply_over_unpacked_dim(
const Fn& fn,
const exec_aten::Tensor& in,
const int64_t& dim) {
if (in.numel() == 0) {
return;
}

ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
ET_CHECK_VALID_DIM(dim, in.dim());

const size_t d = ET_NORMALIZE_IX(dim, in.dim());
const size_t dim_size = in.size(d);
const size_t outer_size = getLeadingDims(in, d);
const size_t inner_size = getTrailingDims(in, d);
// Loop through all outer dimensions
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
// Loop through dim
for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
++unpacked_dim_idx) {
fn(inner_size, outer_idx, unpacked_dim_idx);
}
}
}
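
For reference, a minimal standalone sketch of the traversal that apply_over_unpacked_dim performs for a contiguous tensor of shape [outer, dim, inner]; the shape, values, and names below are hypothetical and use no ExecuTorch types:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical contiguous layout [outer, dim, inner]; a plain array stands
  // in for the tensor payload.
  const size_t outer = 2, dim = 3, inner = 4;
  float data[2 * 3 * 4] = {};

  // The callback receives the inner block size plus the outer and dim indices,
  // and derives the offset of its contiguous block itself, as the lambda in
  // dequantize_per_channel_optimized does.
  auto fn = [&](size_t inner_size, size_t outer_idx, size_t dim_idx) {
    float* block = data + outer_idx * dim * inner + dim_idx * inner;
    for (size_t i = 0; i < inner_size; ++i) {
      block[i] += 1.0f;  // visit each element of this block exactly once
    }
  };

  // Same double loop as apply_over_unpacked_dim: all outer indices, then dim.
  for (size_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
    for (size_t dim_idx = 0; dim_idx < dim; ++dim_idx) {
      fn(inner, outer_idx, dim_idx);
    }
  }
  std::printf("last element touched %.0f time(s)\n", data[outer * dim * inner - 1]);
  return 0;
}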

void dequantize_optimized(
const int8_t* in,
const double scale,
const int64_t zero_point,
float* out,
int64_t quant_min,
int64_t quant_max,
size_t numel) {
ET_CHECK_MSG(
zero_point >= quant_min,
"zero_point must be %" PRId64 " <= quant_min %" PRId64,
zero_point,
quant_min);
ET_CHECK_MSG(
zero_point <= quant_max,
"zero_point must be %" PRId64 " >= quant_max %" PRId64,
zero_point,
quant_max);
size_t i = 0;
#if defined(__aarch64__) || defined(__ARM_NEON)
int8x8_t zero_point_vec = vdup_n_s8(zero_point);
float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
constexpr int32_t kVecSize = 16;
const size_t num_vecs = numel / kVecSize;
const int8_t* in_copy = in;
float* out_copy = out;
for (size_t vec_idx = 0; vec_idx < num_vecs; ++vec_idx) {
int8x16_t in_vec = vld1q_s8(in_copy);
int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);

int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);

// Store all 16 dequantized floats before advancing to the next chunk.
vst1q_f32(out_copy + 0, out_vec_0_3);
vst1q_f32(out_copy + 4, out_vec_4_7);
vst1q_f32(out_copy + 8, out_vec_8_11);
vst1q_f32(out_copy + 12, out_vec_12_15);

in_copy += kVecSize;
out_copy += kVecSize;
}
// Let the scalar loop below handle the remaining (numel % kVecSize) elements.
i = num_vecs * kVecSize;
#endif
for (; i < numel; i++) {
out[i] = (in[i] - zero_point) * scale;
}
}
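
As a sanity check on the math above, here is the scalar affine dequantization formula applied to a few hypothetical int8 values (no NEON, no ExecuTorch dependencies); the vectorized path is expected to produce the same results:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical quantized values and quantization parameters.
  const int8_t in[4] = {-128, 0, 17, 127};
  const double scale = 0.05;
  const int64_t zero_point = 17;
  float out[4];
  for (size_t i = 0; i < 4; ++i) {
    // out[i] = (in[i] - zero_point) * scale, the same formula as the scalar
    // tail loop in dequantize_optimized.
    out[i] = static_cast<float>((in[i] - zero_point) * scale);
  }
  // Prints -7.25, -0.85, 0.00, 5.50
  for (size_t i = 0; i < 4; ++i) {
    std::printf("%.2f\n", out[i]);
  }
  return 0;
}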

bool can_use_optimized_dequantize_per_channel(
const Tensor& in,
const ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
if (!executorch::runtime::is_contiguous_dim_order(
in.dim_order().data(), in.dim()) ||
(in_dtype != ScalarType::Char) ||
(out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
return false;
}
return true;
}

void dequantize_per_channel_optimized(
const Tensor& in,
const Tensor& scales,
const optional<Tensor>& opt_zero_points,
Tensor& out,
int64_t axis,
int64_t quant_min,
int64_t quant_max,
ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
check_dequantize_per_tensor_args(
in, quant_min, quant_max, in_dtype, out_dtype, out);
ET_CHECK_MSG(
executorch::runtime::is_contiguous_dim_order(
in.dim_order().data(), in.dim()),
"in must be in contiguous dim order");
ET_CHECK_MSG(
in_dtype == ScalarType::Char,
"in.scalar_type() %" PRId8 " is not supported:",
static_cast<int8_t>(in.scalar_type()));
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out_dtype.value() == ScalarType::Float,
"Only float output is supported");
}
const int8_t* in_data = in.const_data_ptr<int8_t>();
float* out_data = out.mutable_data_ptr<float>();
const int64_t* zero_points_data = nullptr;
if (opt_zero_points.has_value()) {
zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
}
const double* scales_data = scales.const_data_ptr<double>();
const StridesType axis_stride = in.strides()[axis];
const StridesType outer_stride = in.size(axis) * axis_stride;
apply_over_unpacked_dim(
[in_data,
out_data,
scales_data,
zero_points_data,
axis_stride,
outer_stride,
quant_min,
quant_max](
SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
const int8_t* in_data_local =
in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
const double scale = scales_data[unpacked_dim_idx];
const int64_t zero_point = zero_points_data != nullptr
? zero_points_data[unpacked_dim_idx]
: 0;
float* out_data_local = out_data + outer_idx * outer_stride +
unpacked_dim_idx * axis_stride;
dequantize_optimized(
in_data_local,
scale,
zero_point,
out_data_local,
quant_min,
quant_max,
numel);
},
in,
axis);
}
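
The stride math used by this fast path reduces, for a contiguous [N, C, H, W] tensor dequantized along axis = 1, to axis_stride = H * W and outer_stride = C * H * W, so each (n, c) pair owns one contiguous block of H * W elements dequantized with scales[c] and zero_points[c]. A small sketch with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical contiguous tensor of shape [N, C, H, W], per-channel axis = 1.
  const int64_t N = 2, C = 3, H = 4, W = 5;
  const int64_t axis_stride = H * W;       // elements per channel block
  const int64_t outer_stride = C * H * W;  // elements per batch entry
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t c = 0; c < C; ++c) {
      // Element offset of the block that dequantize_per_channel_optimized
      // passes to dequantize_optimized (as in_data + offset), with
      // numel = axis_stride for a contiguous tensor.
      const int64_t offset = n * outer_stride + c * axis_stride;
      std::printf(
          "n=%lld c=%lld -> elements [%lld, %lld) use scales[%lld]\n",
          static_cast<long long>(n),
          static_cast<long long>(c),
          static_cast<long long>(offset),
          static_cast<long long>(offset + axis_stride),
          static_cast<long long>(c));
    }
  }
  return 0;
}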

} // namespace

/**
@@ -217,6 +376,20 @@ Tensor& dequantize_per_channel_out(
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);

if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
dequantize_per_channel_optimized(
input,
scale,
opt_zero_points,
out,
axis,
quant_min,
quant_max,
dtype,
out_dtype);
return out;
}

// a list containing all dimensions except axis
int64_t dims[kTensorDimensionLimit];
for (int64_t i = 0; i < input.dim() - 1; i++) {