diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 0620fafd363..312ad243d99 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -22,6 +22,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -61,6 +63,163 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Applies the function `fn` over a given dimension `dim` of the tensor `in`.
+ * `fn` should have the following signature:
+ * void fn(const size_t inner_size, const size_t outer_idx,
+ *         const size_t unpacked_dim_idx)
+ * where `inner_size` is the number of contiguous elements that share one
+ * index along `dim`, `outer_idx` indexes the leading (outer) dimensions, and
+ * `unpacked_dim_idx` is the index along `dim` itself.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(static_cast<int8_t>(zero_point));
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  // Process full 16-element chunks; the scalar loop below handles the tail.
+  for (size_t vec_idx = 0; vec_idx < num_vecs; ++vec_idx, i += kVecSize) {
+    int8x16_t in_vec = vld1q_s8(in + i);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    // Store the four dequantized float vectors back to the output buffer.
+    vst1q_f32(out + i, out_vec_0_3);
+    vst1q_f32(out + i + 4, out_vec_4_7);
+    vst1q_f32(out + i + 8, out_vec_8_11);
+    vst1q_f32(out + i + 12, out_vec_12_15);
+  }
+#endif
+  // Scalar path: handles the non-NEON build and any remaining elements.
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  if (!executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()) ||
+      (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const exec_aten::optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()),
+      "in must be in contiguous dim order");
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const double* scales_data = scales.const_data_ptr<double>();
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       scales_data,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = scales_data[unpacked_dim_idx];
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -217,6 +376,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
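For context, the per-channel indexing the optimized path relies on is equivalent to the scalar sketch below. It is illustrative only and not part of the patch; the function name and parameter names are made up. It assumes the contiguous layout that can_use_optimized_dequantize_per_channel verifies, so axis_stride == inner_size and outer_stride == axis_size * inner_size, and each channel owns a contiguous run of inner_size elements that share one (scale, zero_point) pair, which is what lets dequantize_optimized vectorize over it.

#include <cstddef>
#include <cstdint>

// Scalar reference for the indexing used by dequantize_per_channel_optimized.
// Hypothetical helper for illustration; assumes a contiguous int8 input.
void dequantize_per_channel_reference(
    const int8_t* in,
    const double* scales,       // one scale per channel along `axis`
    const int64_t* zero_points, // may be nullptr, treated as all-zero
    float* out,
    size_t outer_size,          // product of dims before `axis`
    size_t axis_size,           // in.size(axis)
    size_t inner_size) {        // product of dims after `axis`
  for (size_t outer = 0; outer < outer_size; ++outer) {
    for (size_t ch = 0; ch < axis_size; ++ch) {
      const double scale = scales[ch];
      const int64_t zp = zero_points != nullptr ? zero_points[ch] : 0;
      // Base offset of this channel's contiguous run within the tensor.
      const size_t base = outer * axis_size * inner_size + ch * inner_size;
      for (size_t i = 0; i < inner_size; ++i) {
        out[base + i] = static_cast<float>((in[base + i] - zp) * scale);
      }
    }
  }
}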