Skip to content

Commit 780ee18

Browse files
authored
[aarch64] Implement QGEMM kernels with UMMLA/SMMLA instructions (#17160)
### Description <!-- Describe your changes. --> This PR adds UMMLA- and SMMLA-based QGEMM kernels for aarch64. It covers the following scenarios: (i) symmetric quantization (zero point is zero), (ii) asymmetric quantization (zero point is non-zero), (iii) per-channel as well as per-tensor quantization, (iv) signed weights (U8S8 GEMM), (v) unsigned weights (U8U8 GEMM), and (vi) signed activations and weights (S8S8 GEMM). I've enabled the UMMLA/SMMLA kernels based on a cpuinfo check for `I8MM` support. The MMLA QGEMM kernels are enabled for all devices that support I8MM instructions. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This change improves INT8 quantized MatMul performance on the aarch64 platform. I ran the benchmarking script below (BERT, RoBERTa, and GPT-2 model inference) on an AWS Graviton3-based c7g.4xl instance and observed up to a 1.33x performance improvement compared to the optimized UDOT QGEMM kernels. ``` cd onnxruntime/python/tools/transformers python3 benchmark.py ``` I have also run the unit tests and made sure they all pass. ``` ./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync ```
1 parent 2a17d5c commit 780ee18

9 files changed

Lines changed: 3833 additions & 0 deletions

File tree

cmake/onnxruntime_mlas.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,9 @@ else()
325325
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
326326
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
327327
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
328+
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
328329
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
330+
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
329331
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
330332
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
331333
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
@@ -334,6 +336,8 @@ else()
334336
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
335337
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
336338
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
339+
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
340+
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
337341
)
338342
if (NOT APPLE)
339343
set(mlas_platform_srcs
@@ -348,6 +352,8 @@ else()
348352
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
349353
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
350354
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
355+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
356+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
351357
endif()
352358

353359
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

onnxruntime/core/common/cpuid_info.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
#define HWCAP_ASIMDDP (1 << 20)
2323
#endif
2424

25+
#ifndef HWCAP2_I8MM
26+
#define HWCAP2_I8MM (1 << 13)
27+
#endif
28+
29+
#ifndef HWCAP2_SVEI8MM
30+
#define HWCAP2_SVEI8MM (1 << 9)
31+
#endif
32+
2533
#endif // ARM
2634

2735
#endif // Linux
@@ -138,6 +146,9 @@ void CPUIDInfo::ArmLinuxInit() {
138146
is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
139147
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
140148
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
149+
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
150+
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
151+
141152
const uint32_t core_cnt = cpuinfo_get_cores_count();
142153
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
143154
is_armv8_narrow_ld_.resize(core_cnt, false);
@@ -162,6 +173,10 @@ void CPUIDInfo::ArmLinuxInit() {
162173
pytorch_cpuinfo_init_ = false;
163174
has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
164175
has_fp16_ |= has_arm_neon_dot_;
176+
177+
has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
178+
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);
179+
165180
#endif
166181
}
167182

@@ -256,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() {
256271

257272
has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
258273
has_fp16_ |= has_arm_neon_dot_;
274+
/* TODO: implement I8MM detection (NEON and SVE) here once hardware and software are available for testing these features */
275+
has_arm_neon_i8mm_ = false;
276+
has_arm_sve_i8mm_ = false;
259277
}
260278

261279
#endif /* (arm or arm64) and windows */

onnxruntime/core/common/cpuid_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class CPUIDInfo {
2828

2929
// ARM
3030
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
31+
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
32+
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
3133

3234
uint32_t GetCurrentCoreIdx() const;
3335

@@ -121,6 +123,8 @@ class CPUIDInfo {
121123

122124
bool has_arm_neon_dot_{false};
123125
bool has_fp16_{false};
126+
bool has_arm_neon_i8mm_{false};
127+
bool has_arm_sve_i8mm_{false};
124128

125129
#ifdef CPUIDINFO_ARCH_X86
126130

0 commit comments

Comments
 (0)