Skip to content

Commit 27d6cbe

Browse files
Nicoshev authored and
facebook-github-bot committed
Fix KleidiAI FP16 (#3769)
Summary: X-link: facebookresearch/FBGEMM#849. FP32Test was failing when KleidiAI was enabled. It turns out that FbgemmFPCommon.h and FbgemmPackMatrixB.h were not conditioned to handle using KleidiAI for FP16 but not for FP32. The PackedGemmMatrixFP16 constructors were moved to a .cc file that is compiled with the rest of fbgemm, which ensures the KleidiAI flag is set when compiling that code. Previously, consumers of the library would include FbgemmPackMatrixB.h from their own .cpp files; because the KleidiAI flag was not set when compiling their code, the intended float16 code path was never run. Additionally, we have ingested a change in KleidiAI's inline assembly, which handles NaN beta values as 0. Reviewed By: psaab. Differential Revision: D70606808
1 parent 9187dc8 commit 27d6cbe

File tree

5 files changed

+171
-41
lines changed

5 files changed

+171
-41
lines changed

defs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_fbgemm_generic_srcs(with_base = False):
2626
"src/FbgemmI64.cc",
2727
"src/FbgemmSparseDense.cc",
2828
"src/FbgemmI8Spmdm.cc",
29+
"src/FbgemmPackMatrixB.cc",
2930
# "src/fp32/FbgemmFP32.cc",
3031
"src/GenerateKernelDirectConvU8S8S32ACC32.cc",
3132
"src/GenerateKernel.cc",

include/fbgemm/FbgemmFPCommon.h

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ struct GemmParams {
3838
float* C;
3939
uint64_t ldc;
4040
uint64_t b_block_cols;
41+
uint64_t b_block_size;
42+
};
43+
44+
template <>
45+
struct GemmParams<float16> {
46+
uint64_t k;
47+
float* A;
48+
const float16* B;
49+
float beta;
50+
float* C;
51+
uint64_t ldc;
52+
uint64_t b_block_cols;
4153
#ifdef FBGEMM_ENABLE_KLEIDIAI
4254
uint64_t lda;
4355
#else
@@ -163,10 +175,15 @@ void cblas_gemm_compute(
163175
assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
164176
if (m != 1) {
165177
#ifdef FBGEMM_ENABLE_KLEIDIAI
166-
gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
167-
#else
168-
PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
169-
gp.A = scratchpad->data();
178+
if constexpr (std::is_same<T, float16>::value) {
179+
gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
180+
} else {
181+
#endif
182+
PackA(
183+
kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
184+
gp.A = scratchpad->data();
185+
#ifdef FBGEMM_ENABLE_KLEIDIAI
186+
}
170187
#endif
171188
} else {
172189
// When m == 1, it is actually vector matrix multiplication. We
@@ -184,11 +201,14 @@ void cblas_gemm_compute(
184201
gp.ldc = ldc * sizeof(C[0]);
185202
gp.b_block_cols = nbcol;
186203
#ifdef FBGEMM_ENABLE_KLEIDIAI
187-
gp.lda = k * sizeof(A[0]);
188-
#else
189-
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
204+
if constexpr (std::is_same<T, float16>::value) {
205+
gp.lda = k * sizeof(A[0]);
206+
} else {
207+
#endif
208+
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
209+
#ifdef FBGEMM_ENABLE_KLEIDIAI
210+
}
190211
#endif
191-
192212
if ((n % Bp.blockColSize()) == 0) {
193213
int64_t jb_begin, jb_end;
194214
fbgemmPartition1D(

include/fbgemm/FbgemmPackMatrixB.h

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,12 @@ class PackedGemmMatrixB {
6262
const float alpha,
6363
const float* smat,
6464
const int brow = 512)
65-
: nrow_(nrow),
66-
ncol_(ncol),
67-
brow_(brow),
68-
#ifdef FBGEMM_ENABLE_KLEIDIAI
69-
kernel_ncol_blocks_(1)
70-
#else
71-
kernel_ncol_blocks_(2)
65+
: nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) {
66+
#if defined(FBGEMM_ENABLE_KLEIDIAI)
67+
if (std::is_same<T, float16>::value) {
68+
kernel_ncol_blocks_ = 1;
69+
}
7270
#endif
73-
{
7471
initializeParam();
7572
initializeMemory();
7673
// copy source matrix into packed matrix
@@ -95,6 +92,11 @@ class PackedGemmMatrixB {
9592
nbcol_(nbcol),
9693
size_(size),
9794
kernel_ncol_blocks_(2) {
95+
#if defined(FBGEMM_ENABLE_KLEIDIAI)
96+
if (std::is_same<T, float16>::value) {
97+
kernel_ncol_blocks_ = 1;
98+
}
99+
#endif
98100
initializeMemory();
99101
}
100102

@@ -297,4 +299,30 @@ class PackedGemmMatrixB {
297299
bool pmat_passed_in{false};
298300
};
299301

302+
#ifndef FBGEMM_STATIC
303+
304+
template <>
305+
FBGEMM_API
306+
PackedGemmMatrixB<float16, TypeConverter<float16>>::PackedGemmMatrixB(
307+
const matrix_op_t trans,
308+
const int nrow,
309+
const int ncol,
310+
const float alpha,
311+
const float* smat,
312+
const int brow);
313+
314+
template <>
315+
FBGEMM_API
316+
PackedGemmMatrixB<float16, TypeConverter<float16>>::PackedGemmMatrixB(
317+
const int nrow,
318+
const int ncol,
319+
const int brow,
320+
const int last_brow,
321+
const int bcol,
322+
const int nbrow,
323+
const int nbcol,
324+
const uint64_t size);
325+
326+
#endif
327+
300328
} // namespace fbgemm

src/FbgemmPackMatrixB.cc

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fbgemm/FbgemmFP16.h"
10+
#define FBGEMM_EXPORTS
11+
12+
namespace fbgemm {
13+
14+
// takes smat input matrix in row-major format;
15+
// packs it into gemm-friendly blocked format;
16+
// allocate space and sets up all the internal variables;
17+
// also premultiplies by alpha during packing.
18+
// brow_ contains tile size along k dimension
19+
// and also is # of fmas updates into int16 container
20+
// before flushing into fp32.
21+
// the smaller the brow_, the higher overhead
22+
// of flushing is.
23+
// kernel_ncol_blocks is the number of column blocks (in the size of 8 fp16,
24+
// or 128 bit, or 1 xmm register size) in the kernel. Because the batch size
25+
// can be dynamic and we need to prepack the weight matrix B, the internal
26+
// packing layout of the weight matrix and kernel_ncol_blocks have to be
27+
// fixed. We can choose kernel_ncol_blocks = 1 (with kernels of 1x1~14x1
28+
// register layouts), 2 (with kernels of 1x2~6x2 register layout), or 3 (with
29+
// kernels of 1x3~4x3 register layout).
30+
31+
#ifndef FBGEMM_STATIC
32+
33+
template <>
34+
FBGEMM_API
35+
PackedGemmMatrixB<float16, TypeConverter<float16>>::PackedGemmMatrixB(
36+
const matrix_op_t trans,
37+
const int nrow,
38+
const int ncol,
39+
const float alpha,
40+
const float* smat,
41+
const int brow)
42+
: nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) {
43+
#if defined(FBGEMM_ENABLE_KLEIDIAI)
44+
kernel_ncol_blocks_ = 1;
45+
#endif
46+
initializeParam();
47+
initializeMemory();
48+
// copy source matrix into packed matrix
49+
this->PackedGemmMatrixB<float16, TypeConverter<float16>>::packFromSrc(
50+
trans, alpha, smat);
51+
}
52+
53+
template <>
54+
FBGEMM_API
55+
PackedGemmMatrixB<float16, TypeConverter<float16>>::PackedGemmMatrixB(
56+
const int nrow,
57+
const int ncol,
58+
const int brow,
59+
const int last_brow,
60+
const int bcol,
61+
const int nbrow,
62+
const int nbcol,
63+
const uint64_t size)
64+
: nrow_(nrow),
65+
ncol_(ncol),
66+
brow_(brow),
67+
last_brow_(last_brow),
68+
bcol_(bcol),
69+
nbrow_(nbrow),
70+
nbcol_(nbcol),
71+
size_(size),
72+
kernel_ncol_blocks_(2) {
73+
#if defined(FBGEMM_ENABLE_KLEIDIAI)
74+
kernel_ncol_blocks_ = 1;
75+
#endif
76+
initializeMemory();
77+
}
78+
79+
#endif
80+
81+
} // namespace fbgemm

src/KleidiAIFP16UKernelsNeon.cc

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// @lint-ignore-every LICENSELINT
22
//
3-
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
3+
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
44
55
//
66
// SPDX-License-Identifier: Apache-2.0
@@ -15,15 +15,15 @@ namespace kleidiai {
1515
void NOINLINE gemmkernel_1x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
1616
#ifdef __aarch64__
1717
__asm__ __volatile__(
18-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
18+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
1919
"mov x25, #0x1\n"
2020
"fmov v29.8h, #1.0\n"
2121
"ldr x24, [%x[gp], %[offsetof_b_block_cols]]\n"
2222
"ldr x23, [%x[gp], %[offsetof_B]]\n"
2323
"ldr x22, [%x[gp], %[offsetof_C]]\n"
24-
"bic x20, x20, #0x80000000\n"
25-
"cmp x20, #0x0\n"
24+
"fcmp s16, #0.0\n"
2625
"csel x25, XZR, x25, EQ\n"
26+
"csel x25, XZR, x25, VS\n"
2727
"1:" // Height 1: Column loop
2828
"tbz x25, #0, 2f\n"
2929
"ldr q30, [x22, #0x0]\n"
@@ -177,15 +177,15 @@ void NOINLINE gemmkernel_1x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
177177
void NOINLINE gemmkernel_2x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
178178
#ifdef __aarch64__
179179
__asm__ __volatile__(
180-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
180+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
181181
"mov x26, #0x1\n"
182182
"fmov v27.8h, #1.0\n"
183183
"ldr x25, [%x[gp], %[offsetof_b_block_cols]]\n"
184184
"ldr x24, [%x[gp], %[offsetof_B]]\n"
185185
"ldr x23, [%x[gp], %[offsetof_C]]\n"
186-
"bic x20, x20, #0x80000000\n"
187-
"cmp x20, #0x0\n"
186+
"fcmp s16, #0.0\n"
188187
"csel x26, XZR, x26, EQ\n"
188+
"csel x26, XZR, x26, VS\n"
189189
"1:" // Height 2: Column loop
190190
"tbz x26, #0, 2f\n"
191191
"ldr q28, [x23, #0x0]\n"
@@ -384,15 +384,15 @@ void NOINLINE gemmkernel_2x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
384384
void NOINLINE gemmkernel_3x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
385385
#ifdef __aarch64__
386386
__asm__ __volatile__(
387-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
387+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
388388
"mov x27, #0x1\n"
389389
"fmov v25.8h, #1.0\n"
390390
"ldr x26, [%x[gp], %[offsetof_b_block_cols]]\n"
391391
"ldr x25, [%x[gp], %[offsetof_B]]\n"
392392
"ldr x24, [%x[gp], %[offsetof_C]]\n"
393-
"bic x20, x20, #0x80000000\n"
394-
"cmp x20, #0x0\n"
393+
"fcmp s16, #0.0\n"
395394
"csel x27, XZR, x27, EQ\n"
395+
"csel x27, XZR, x27, VS\n"
396396
"1:" // Height 3: Column loop
397397
"tbz x27, #0, 2f\n"
398398
"ldr q26, [x24, #0x0]\n"
@@ -632,15 +632,15 @@ void NOINLINE gemmkernel_3x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
632632
void NOINLINE gemmkernel_4x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
633633
#ifdef __aarch64__
634634
__asm__ __volatile__(
635-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
635+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
636636
"mov x28, #0x1\n"
637637
"fmov v23.8h, #1.0\n"
638638
"ldr x27, [%x[gp], %[offsetof_b_block_cols]]\n"
639639
"ldr x26, [%x[gp], %[offsetof_B]]\n"
640640
"ldr x25, [%x[gp], %[offsetof_C]]\n"
641-
"bic x20, x20, #0x80000000\n"
642-
"cmp x20, #0x0\n"
641+
"fcmp s16, #0.0\n"
643642
"csel x28, XZR, x28, EQ\n"
643+
"csel x28, XZR, x28, VS\n"
644644
"1:" // Height 4: Column loop
645645
"tbz x28, #0, 2f\n"
646646
"ldr q24, [x25, #0x0]\n"
@@ -921,15 +921,15 @@ void NOINLINE gemmkernel_4x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
921921
void NOINLINE gemmkernel_5x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
922922
#ifdef __aarch64__
923923
__asm__ __volatile__(
924-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
924+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
925925
"mov x9, #0x1\n"
926926
"fmov v21.8h, #1.0\n"
927927
"ldr x28, [%x[gp], %[offsetof_b_block_cols]]\n"
928928
"ldr x27, [%x[gp], %[offsetof_B]]\n"
929929
"ldr x26, [%x[gp], %[offsetof_C]]\n"
930-
"bic x20, x20, #0x80000000\n"
931-
"cmp x20, #0x0\n"
930+
"fcmp s16, #0.0\n"
932931
"csel x9, XZR, x9, EQ\n"
932+
"csel x9, XZR, x9, VS\n"
933933
"1:" // Height 5: Column loop
934934
"tbz x9, #0, 2f\n"
935935
"ldr q22, [x26, #0x0]\n"
@@ -1251,15 +1251,15 @@ void NOINLINE gemmkernel_5x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
12511251
void NOINLINE gemmkernel_6x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
12521252
#ifdef __aarch64__
12531253
__asm__ __volatile__(
1254-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
1254+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
12551255
"mov x10, #0x1\n"
12561256
"fmov v19.8h, #1.0\n"
12571257
"ldr x9, [%x[gp], %[offsetof_b_block_cols]]\n"
12581258
"ldr x28, [%x[gp], %[offsetof_B]]\n"
12591259
"ldr x27, [%x[gp], %[offsetof_C]]\n"
1260-
"bic x20, x20, #0x80000000\n"
1261-
"cmp x20, #0x0\n"
1260+
"fcmp s16, #0.0\n"
12621261
"csel x10, XZR, x10, EQ\n"
1262+
"csel x10, XZR, x10, VS\n"
12631263
"1:" // Height 6: Column loop
12641264
"tbz x10, #0, 2f\n"
12651265
"ldr q20, [x27, #0x0]\n"
@@ -1620,15 +1620,15 @@ void NOINLINE gemmkernel_6x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
16201620
void NOINLINE gemmkernel_7x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
16211621
#ifdef __aarch64__
16221622
__asm__ __volatile__(
1623-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
1623+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
16241624
"mov x11, #0x1\n"
16251625
"fmov v17.8h, #1.0\n"
16261626
"ldr x10, [%x[gp], %[offsetof_b_block_cols]]\n"
16271627
"ldr x9, [%x[gp], %[offsetof_B]]\n"
16281628
"ldr x28, [%x[gp], %[offsetof_C]]\n"
1629-
"bic x20, x20, #0x80000000\n"
1630-
"cmp x20, #0x0\n"
1629+
"fcmp s16, #0.0\n"
16311630
"csel x11, XZR, x11, EQ\n"
1631+
"csel x11, XZR, x11, VS\n"
16321632
"1:" // Height 7: Column loop
16331633
"tbz x11, #0, 2f\n"
16341634
"ldr q18, [x28, #0x0]\n"
@@ -2027,15 +2027,15 @@ void NOINLINE gemmkernel_7x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
20272027
void NOINLINE gemmkernel_8x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
20282028
#ifdef __aarch64__
20292029
__asm__ __volatile__(
2030-
"ldr w20, [%x[gp], %[offsetof_beta]]\n"
2030+
"ldr s16, [%x[gp], %[offsetof_beta]]\n"
20312031
"mov x12, #0x1\n"
20322032
"fmov v15.8h, #1.0\n"
20332033
"ldr x11, [%x[gp], %[offsetof_b_block_cols]]\n"
20342034
"ldr x10, [%x[gp], %[offsetof_B]]\n"
20352035
"ldr x9, [%x[gp], %[offsetof_C]]\n"
2036-
"bic x20, x20, #0x80000000\n"
2037-
"cmp x20, #0x0\n"
2036+
"fcmp s16, #0.0\n"
20382037
"csel x12, XZR, x12, EQ\n"
2038+
"csel x12, XZR, x12, VS\n"
20392039
"1:" // Height 8: Column loop
20402040
"tbz x12, #0, 2f\n"
20412041
"ldr q16, [x9, #0x0]\n"

0 commit comments

Comments
 (0)