Skip to content

Commit dfeb094

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Add NEON implementation of Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf (#3707)
Summary: X-link: facebookresearch/FBGEMM#789 Pull Request resolved: #3707 QuantUtilsNeon.cc has been introduced, Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon implemented as first function We have observed a ~x12 performance improvement for the downcasting case. The case where a float32_t is returned maintains the same speed: Full results: before: P1732996851 after: P1732996401 Differential Revision: D69573878
1 parent dcefaaa commit dfeb094

File tree

5 files changed

+157
-0
lines changed

5 files changed

+157
-0
lines changed

defs.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def get_fbgemm_public_headers():
7373
"include/fbgemm/QuantUtils.h",
7474
"include/fbgemm/QuantUtilsAvx2.h",
7575
"include/fbgemm/QuantUtilsAvx512.h",
76+
"include/fbgemm/QuantUtilsNeon.h",
7677
"include/fbgemm/spmmUtils.h",
7778
"include/fbgemm/spmmUtilsAvx2.h",
7879
"include/fbgemm/SimdUtils.h",
@@ -153,6 +154,7 @@ def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
153154
intrinsics_srcs = [
154155
"src/FbgemmFP16UKernelsSve128.cc",
155156
"src/KleidiAIFP16UKernelsNeon.cc",
157+
"src/QuantUtilsNeon.cc",
156158
"src/UtilsSve.cc",
157159
] + select({
158160
"DEFAULT": [],
@@ -165,6 +167,7 @@ def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
165167
asm_srcs = [
166168
"src/FbgemmFP16UKernelsSve128.cc",
167169
"src/KleidiAIFP16UKernelsNeon.cc",
170+
"src/QuantUtilsNeon.cc",
168171
"src/UtilsSve.cc",
169172
] + select({
170173
"DEFAULT": [],

include/fbgemm/QuantUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "./FbgemmBuild.h"
1212
#include "./QuantUtilsAvx2.h"
13+
#include "./QuantUtilsNeon.h"
1314
#include "./Types.h"
1415
#include "./Utils.h"
1516

include/fbgemm/QuantUtilsNeon.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef __aarch64__
12+
13+
#include <cstdint>
14+
#include "./FbgemmBuild.h"
15+
16+
/// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2)
17+
///
18+
19+
namespace fbgemm {
20+
21+
////////////////////////////////////////////////////////////////////////////////
22+
// Utility functions
23+
////////////////////////////////////////////////////////////////////////////////
24+
25+
template <typename OutputType>
26+
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
27+
const std::uint8_t* input,
28+
size_t input_rows,
29+
int input_columns,
30+
OutputType* output);
31+
32+
} // namespace fbgemm
33+
34+
#endif // __aarch64__

src/QuantUtils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,10 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
836836
size_t input_rows,
837837
int input_columns,
838838
OutputType* output) {
839+
#if HAVE_SVE
840+
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon<OutputType>(
841+
input, input_rows, input_columns, output);
842+
#else
839843
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
840844
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
841845
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
@@ -845,6 +849,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
845849
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
846850
input, input_rows, input_columns, output);
847851
}
852+
#endif
848853
}
849854

850855
#define INSTANTIATE_QuantizationFunctions(type) \

src/QuantUtilsNeon.cc

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fbgemm/Utils.h"
10+
11+
#if HAVE_SVE
12+
13+
#define FBGEMM_EXPORTS
14+
#include <arm_neon.h>
15+
#include <arm_sve.h>
16+
17+
#include <arm_neon_sve_bridge.h>
18+
#include <algorithm> //for std::min/std::max
19+
#include <cassert> //for assert
20+
#include <cfloat> // for FLT_MAX
21+
#include <cmath> //for nearbyint
22+
#include <cstring> //for memcpy
23+
#include <limits> //for numeric_limits
24+
#include "fbgemm/QuantUtilsNeon.h"
25+
#include "fbgemm/Types.h"
26+
27+
namespace fbgemm {
28+
29+
using namespace std;
30+
////////////////////////////////////////////////////////////////////////////////
31+
// Utility functions
32+
33+
template <typename OutputType>
34+
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
35+
const std::uint8_t* input,
36+
size_t input_rows,
37+
int input_columns,
38+
OutputType* output) {
39+
int output_columns = input_columns - 2 * sizeof(float);
40+
41+
for (size_t row = 0; row < input_rows; ++row) {
42+
const std::uint8_t* input_row = input + row * input_columns;
43+
const float* input_row_scale_bias =
44+
reinterpret_cast<const float*>(input_row + output_columns);
45+
OutputType* output_row = output + row * output_columns;
46+
47+
svbool_t pred = svptrue_b32();
48+
49+
float scale = input_row_scale_bias[0];
50+
float bias = input_row_scale_bias[1];
51+
svfloat32_t scale_v = svdup_n_f32(scale);
52+
svfloat32_t bias_v = svdup_n_f32(bias);
53+
54+
const uint64_t* input_row_v_0 =
55+
reinterpret_cast<const uint64_t*>(input_row);
56+
const uint64_t* input_row_v_1 =
57+
reinterpret_cast<const uint64_t*>(input_row + 4);
58+
float32x4x2_t* output_row_v = reinterpret_cast<float32x4x2_t*>(output_row);
59+
float16x8_t* output_row_v_half = reinterpret_cast<float16x8_t*>(output_row);
60+
61+
int colIndex = 0;
62+
for (int colMax = output_columns / 8; colIndex < colMax; ++colIndex) {
63+
svuint32_t in_v_0 = svld1ub_u32(
64+
pred, reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
65+
svuint32_t in_v_1 = svld1ub_u32(
66+
pred, reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
67+
svfloat32_t in_v_0_f = svcvt_f32_u32_x(pred, in_v_0);
68+
svfloat32_t in_v_1_f = svcvt_f32_u32_x(pred, in_v_1);
69+
70+
in_v_0_f = svmad_f32_m(pred, in_v_0_f, scale_v, bias_v);
71+
in_v_1_f = svmad_f32_m(pred, in_v_1_f, scale_v, bias_v);
72+
73+
if constexpr (std::is_same<OutputType, float>()) {
74+
output_row_v[colIndex].val[0] = svget_neonq(in_v_0_f);
75+
output_row_v[colIndex].val[1] = svget_neonq(in_v_1_f);
76+
} else {
77+
float16x4_t dequantzed_v_half_low_low =
78+
vcvt_f16_f32(svget_neonq(in_v_0_f));
79+
float16x8_t dequantzed_v_half_low =
80+
vcvt_high_f16_f32(dequantzed_v_half_low_low, svget_neonq(in_v_1_f));
81+
output_row_v_half[colIndex] = dequantzed_v_half_low;
82+
}
83+
}
84+
85+
#pragma clang loop vectorize(disable)
86+
#pragma clang loop unroll(disable)
87+
for (colIndex *= 8; colIndex < output_columns; ++colIndex) {
88+
float output_value = input_row[colIndex] * input_row_scale_bias[0] +
89+
input_row_scale_bias[1];
90+
if (std::is_same<OutputType, float>()) {
91+
output_row[colIndex] = output_value;
92+
} else {
93+
output_row[colIndex] = cpu_float2half_rn(output_value);
94+
}
95+
}
96+
} // for each row
97+
}
98+
99+
#define INSTANTIATE_QuantizationNeonFunctions8Bits(type) \
100+
template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon<type>( \
101+
const std::uint8_t* input, \
102+
size_t input_rows, \
103+
int input_columns, \
104+
type* output);
105+
106+
// clang-format off
107+
INSTANTIATE_QuantizationNeonFunctions8Bits(float)
108+
INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
109+
// clang-format on
110+
#undef INSTANTIATE_QuantizationNeonFunctions8Bits
111+
112+
} // namespace fbgemm
113+
114+
#endif // __aarch64__

0 commit comments

Comments
 (0)