Skip to content

Commit 1af37ae

Browse files
snadampalkleiti
authored andcommitted
[aarch64] Implement QGEMM kernels with UMMLA/SMMLA instructions (microsoft#17160)
### Description <!-- Describe your changes. --> This PR adds UMMLA and SMMLA based QGEMM kernels for aarch64. This covers (i) symmetric quantization (zero point is Zero) (ii) asymmetric quantization (zero point is non zero) (iii) per channel as well as per tensor quantization (iv) Signed weights (U8S8 Gemm) (v) Unsigned weights (U8U8 Gemm) and (vi) Signed activations and weights (S8S8 Gemm) scenarios I've enabled the ummla/smmla kernels based on cpuinfo check for `I8MM` support MMLA QGEMM kernels are enabled for all the devices that support I8MM instructions. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This is to improve INT8 quantized MatMul performance on aarch64 platform. I have run the below benchmarking script (bert , roberta and gpt2 model inference) on AWS Graviton3 based c7g.4xl instance and observed up to 1.33x performance improvement compared to the optimized UDOT qgemm kernel performance. ``` cd onnxruntime/python/tools/transformers python3 benchmark.py ``` I have also run the unit tests, and made sure all are passing ``` ./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync ```
1 parent 07bfbbb commit 1af37ae

9 files changed

Lines changed: 3833 additions & 0 deletions

File tree

cmake/onnxruntime_mlas.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,9 @@ else()
325325
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
326326
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
327327
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
328+
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
328329
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
330+
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
329331
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
330332
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
331333
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
@@ -334,6 +336,8 @@ else()
334336
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
335337
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
336338
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
339+
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
340+
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
337341
)
338342
if (NOT APPLE)
339343
set(mlas_platform_srcs
@@ -348,6 +352,8 @@ else()
348352
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
349353
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
350354
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
355+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
356+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
351357
endif()
352358

353359
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

onnxruntime/core/common/cpuid_info.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
#define HWCAP_ASIMDDP (1 << 20)
2323
#endif
2424

25+
#ifndef HWCAP2_I8MM
26+
#define HWCAP2_I8MM (1 << 13)
27+
#endif
28+
29+
#ifndef HWCAP2_SVEI8MM
30+
#define HWCAP2_SVEI8MM (1 << 9)
31+
#endif
32+
2533
#endif // ARM
2634

2735
#endif // Linux
@@ -160,6 +168,9 @@ void CPUIDInfo::ArmLinuxInit() {
160168
is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
161169
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
162170
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
171+
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
172+
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
173+
163174
const uint32_t core_cnt = cpuinfo_get_cores_count();
164175
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
165176
is_armv8_narrow_ld_.resize(core_cnt, false);
@@ -184,6 +195,10 @@ void CPUIDInfo::ArmLinuxInit() {
184195
pytorch_cpuinfo_init_ = false;
185196
has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
186197
has_fp16_ |= has_arm_neon_dot_;
198+
199+
has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
200+
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);
201+
187202
#endif
188203
}
189204

@@ -278,6 +293,9 @@ void CPUIDInfo::ArmWindowsInit() {
278293

279294
has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
280295
has_fp16_ |= has_arm_neon_dot_;
296+
/* TODO: implement them when hw+sw is available for testing these features */
297+
has_arm_neon_i8mm_ = false;
298+
has_arm_sve_i8mm_ = false;
281299
}
282300

283301
#endif /* (arm or arm64) and windows */

onnxruntime/core/common/cpuid_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class CPUIDInfo {
2929

3030
// ARM
3131
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
32+
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
33+
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
3234

3335
uint32_t GetCurrentCoreIdx() const;
3436

@@ -123,6 +125,8 @@ class CPUIDInfo {
123125

124126
bool has_arm_neon_dot_{false};
125127
bool has_fp16_{false};
128+
bool has_arm_neon_i8mm_{false};
129+
bool has_arm_sve_i8mm_{false};
126130

127131
#ifdef CPUIDINFO_ARCH_X86
128132

0 commit comments

Comments
 (0)