Skip to content

Commit 91c84e1

Browse files
authored
Merge pull request #2796 from Guobing-Chen/BF16_dot_coversion_apis
Add bfloat16 based dot and conversion with single/double
2 parents 1ee1e7b + deaeb6c commit 91c84e1

31 files changed

+1392
-82
lines changed

Makefile.tail

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
55
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
66
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
77
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
8+
SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX))
89

910
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1011

1112
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1213

13-
BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
14-
BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
14+
BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
15+
BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
1516

1617
ifdef EXPRECISION
1718
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
3031
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX
3132
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
3233
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
34+
$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
3335

3436
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3537
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3840
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3941
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4042
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
43+
$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4144

4245
libs :: $(BLASOBJS) $(COMMONOBJS)
4346
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

cblas.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
382382
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
383383
double *c, OPENBLAS_CONST blasint cldc);
384384

385+
/*** BFLOAT16 and INT8 extensions ***/
386+
/* convert float array to BFLOAT16 array by rounding */
387+
void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
388+
/* convert double array to BFLOAT16 array by rounding */
389+
void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
390+
/* convert BFLOAT16 array to float array */
391+
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout);
392+
/* convert BFLOAT16 array to double array */
393+
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
394+
/* dot production of BFLOAT16 input arrays, and output as float */
395+
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
385396

386397
#ifdef __cplusplus
387398
}

cmake/kernel.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,14 @@ if (BUILD_HALF)
126126
set(SHAXPYKERNEL ../arm/axpy.c)
127127
set(SHAXPBYKERNEL ../arm/axpby.c)
128128
set(SHCOPYKERNEL ../arm/copy.c)
129-
set(SHDOTKERNEL ../arm/dot.c)
129+
set(SHDOTKERNEL ../x86_64/shdot.c)
130130
set(SHROTKERNEL ../arm/rot.c)
131131
set(SHSCALKERNEL ../arm/scal.c)
132132
set(SHNRM2KERNEL ../arm/nrm2.c)
133133
set(SHSUMKERNEL ../arm/sum.c)
134134
set(SHSWAPKERNEL ../arm/swap.c)
135+
set(TOBF16KERNEL ../x86_64/tobf16.c)
136+
set(BF16TOKERNEL ../x86_64/bf16to.c)
135137
endif ()
136138
endmacro ()
137139

common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG;
258258
#endif
259259

260260
#ifndef BFLOAT16
261-
typedef unsigned short bfloat16;
261+
#include <stdint.h>
262+
typedef uint16_t bfloat16;
262263
#define HALFCONVERSION 1
263264
#endif
264265

common_interface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
5454
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
5555
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
5656

57+
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
58+
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
59+
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
60+
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *);
61+
void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *);
5762

5863
#ifdef RETURN_BY_STRUCT
5964
typedef struct {

common_level1.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4646
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4747
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG);
4848
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
49+
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
50+
51+
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
52+
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
53+
void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
54+
void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
4955

5056
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
5157
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);

common_macro.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,11 @@
646646

647647
#elif defined(HALF)
648648

649+
#define D_TO_BF16_K SHDTOBF16_K
650+
#define D_BF16_TO_K DBF16TOD_K
651+
#define S_TO_BF16_K SHSTOBF16_K
652+
#define S_BF16_TO_K SBF16TOS_K
653+
649654
#define AMAX_K SAMAX_K
650655
#define AMIN_K SAMIN_K
651656
#define MAX_K SMAX_K
@@ -657,6 +662,7 @@
657662
#define ASUM_K SASUM_K
658663
#define DOTU_K SDOTU_K
659664
#define DOTC_K SDOTC_K
665+
#define BF16_DOT_K SHDOT_K
660666
#define AXPYU_K SAXPYU_K
661667
#define AXPYC_K SAXPYC_K
662668
#define AXPBY_K SAXPBY_K

common_param.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ typedef struct {
5151
int shgemm_p, shgemm_q, shgemm_r;
5252
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
5353

54+
void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
55+
void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
56+
void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
57+
void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
58+
5459
float (*shamax_k) (BLASLONG, float *, BLASLONG);
5560
float (*shamin_k) (BLASLONG, float *, BLASLONG);
5661
float (*shmax_k) (BLASLONG, float *, BLASLONG);
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
6469
float (*shasum_k) (BLASLONG, float *, BLASLONG);
6570
float (*shsum_k) (BLASLONG, float *, BLASLONG);
6671
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
67-
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
72+
float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
6873
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
6974

7075
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);

common_sh.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33

44
#ifndef DYNAMIC_ARCH
55

6+
#define SHDOT_K shdot_k
7+
#define SHSTOBF16_K shstobf16_k
8+
#define SHDTOBF16_K shdtobf16_k
9+
#define SBF16TOS_K sbf16tos_k
10+
#define DBF16TOD_K dbf16tod_k
11+
612
#define SHGEMM_ONCOPY shgemm_oncopy
713
#define SHGEMM_OTCOPY shgemm_otcopy
814

@@ -18,6 +24,12 @@
1824

1925
#else
2026

27+
#define SHDOT_K gotoblas -> shdot_k
28+
#define SHSTOBF16_K gotoblas -> shstobf16_k
29+
#define SHDTOBF16_K gotoblas -> shdtobf16_k
30+
#define SBF16TOS_K gotoblas -> sbf16tos_k
31+
#define DBF16TOD_K gotoblas -> dbf16tod_k
32+
2133
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
2234
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
2335
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy

common_thread.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,19 @@ extern int blas_omp_linked;
5959
#define BLAS_PTHREAD 0x4000U
6060
#define BLAS_NODE 0x2000U
6161

62-
#define BLAS_PREC 0x0003U
63-
#define BLAS_SINGLE 0x0000U
64-
#define BLAS_DOUBLE 0x0001U
65-
#define BLAS_XDOUBLE 0x0002U
66-
#define BLAS_REAL 0x0000U
67-
#define BLAS_COMPLEX 0x0004U
62+
#define BLAS_PREC 0x000FU
63+
#define BLAS_INT8 0x0000U
64+
#define BLAS_BFLOAT16 0x0001U
65+
#define BLAS_SINGLE 0x0002U
66+
#define BLAS_DOUBLE 0x0003U
67+
#define BLAS_XDOUBLE 0x0004U
68+
#define BLAS_STOBF16 0x0008U
69+
#define BLAS_DTOBF16 0x0009U
70+
#define BLAS_BF16TOS 0x000AU
71+
#define BLAS_BF16TOD 0x000BU
72+
73+
#define BLAS_REAL 0x0000U
74+
#define BLAS_COMPLEX 0x1000U
6875

6976
#define BLAS_TRANSA 0x0030U /* 2bit */
7077
#define BLAS_TRANSA_N 0x0000U

common_x86_64.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){
142142
#endif
143143
}
144144

145+
static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx)
146+
{
147+
#ifdef C_MSVC
148+
int cpuInfo[4] = {-1};
149+
__cpuidex(cpuInfo, op, count);
150+
*eax = cpuInfo[0];
151+
*ebx = cpuInfo[1];
152+
*ecx = cpuInfo[2];
153+
*edx = cpuInfo[3];
154+
#else
155+
#if defined(__i386__) && defined(__PIC__)
156+
__asm__ __volatile__
157+
("mov %%ebx, %%edi;"
158+
"cpuid;"
159+
"xchgl %%ebx, %%edi;"
160+
: "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
161+
#else
162+
__asm__ __volatile__
163+
("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
164+
#endif
165+
#endif
166+
}
167+
145168
/*
146169
#define WHEREAMI
147170
*/

driver/others/blas_l1_thread.c

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
4949
blas_arg_t args [MAX_CPU_NUMBER];
5050

5151
BLASLONG i, width, astride, bstride;
52-
int num_cpu, calc_type;
53-
54-
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
52+
int num_cpu, calc_type_a, calc_type_b;
53+
54+
switch (mode & BLAS_PREC) {
55+
case BLAS_INT8 :
56+
case BLAS_BFLOAT16:
57+
case BLAS_SINGLE :
58+
case BLAS_DOUBLE :
59+
case BLAS_XDOUBLE :
60+
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
61+
break;
62+
case BLAS_STOBF16 :
63+
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
64+
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
65+
break;
66+
case BLAS_DTOBF16 :
67+
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
68+
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
69+
break;
70+
case BLAS_BF16TOS :
71+
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
72+
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
73+
break;
74+
case BLAS_BF16TOD :
75+
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
76+
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
77+
break;
78+
default:
79+
calc_type_a = calc_type_b = 0;
80+
break;
81+
}
5582

5683
mode |= BLAS_LEGACY;
5784

@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
77104
bstride = width;
78105
}
79106

80-
astride <<= calc_type;
81-
bstride <<= calc_type;
107+
astride <<= calc_type_a;
108+
bstride <<= calc_type_b;
82109

83110
args[num_cpu].m = width;
84111
args[num_cpu].n = n;
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
120147
blas_arg_t args [MAX_CPU_NUMBER];
121148

122149
BLASLONG i, width, astride, bstride;
123-
int num_cpu, calc_type;
124-
125-
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
150+
int num_cpu, calc_type_a, calc_type_b;
151+
152+
switch (mode & BLAS_PREC) {
153+
case BLAS_INT8 :
154+
case BLAS_BFLOAT16:
155+
case BLAS_SINGLE :
156+
case BLAS_DOUBLE :
157+
case BLAS_XDOUBLE :
158+
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
159+
break;
160+
case BLAS_STOBF16 :
161+
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
162+
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
163+
break;
164+
case BLAS_DTOBF16 :
165+
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
166+
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
167+
break;
168+
case BLAS_BF16TOS :
169+
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
170+
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
171+
break;
172+
case BLAS_BF16TOD :
173+
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
174+
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
175+
break;
176+
default:
177+
calc_type_a = calc_type_b = 0;
178+
break;
179+
}
126180

127181
mode |= BLAS_LEGACY;
128182

@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
148202
bstride = width;
149203
}
150204

151-
astride <<= calc_type;
152-
bstride <<= calc_type;
205+
astride <<= calc_type_a;
206+
bstride <<= calc_type_b;
153207

154208
args[num_cpu].m = width;
155209
args[num_cpu].n = n;

0 commit comments

Comments
 (0)