|
1 | 1 | /***************************************************************************
|
| 2 | +Copyright (c) 2023, The OpenBLAS Project |
2 | 3 | Copyright (c) 2022, Arm Ltd
|
3 | 4 | All rights reserved.
|
4 | 5 | Redistribution and use in source and binary forms, with or without
|
@@ -30,37 +31,84 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
30 | 31 | #include <arm_sve.h>
|
31 | 32 |
|
32 | 33 | #ifdef DOUBLE
|
33 |
| -#define SVE_TYPE svfloat64_t |
34 |
| -#define SVE_ZERO svdup_f64(0.0) |
35 |
| -#define SVE_WHILELT svwhilelt_b64 |
36 |
| -#define SVE_ALL svptrue_b64() |
37 |
| -#define SVE_WIDTH svcntd() |
| 34 | +#define DTYPE "d" |
| 35 | +#define WIDTH "d" |
| 36 | +#define SHIFT "3" |
38 | 37 | #else
|
39 |
| -#define SVE_TYPE svfloat32_t |
40 |
| -#define SVE_ZERO svdup_f32(0.0) |
41 |
| -#define SVE_WHILELT svwhilelt_b32 |
42 |
| -#define SVE_ALL svptrue_b32() |
43 |
| -#define SVE_WIDTH svcntw() |
| 38 | +#define DTYPE "s" |
| 39 | +#define WIDTH "w" |
| 40 | +#define SHIFT "2" |
44 | 41 | #endif
|
45 | 42 |
|
46 |
| -static FLOAT dot_kernel_sve(BLASLONG n, FLOAT *x, FLOAT *y) { |
47 |
| - SVE_TYPE acc_a = SVE_ZERO; |
48 |
| - SVE_TYPE acc_b = SVE_ZERO; |
| 43 | +#define COUNT \ |
| 44 | +" cnt"WIDTH" x9 \n" |
| 45 | +#define SETUP_TRUE \ |
| 46 | +" ptrue p0."DTYPE" \n" |
| 47 | +#define OFFSET_INPUTS \ |
| 48 | +" add x12, %[X_], x9, lsl #"SHIFT" \n" \ |
| 49 | +" add x13, %[Y_], x9, lsl #"SHIFT" \n" |
| 50 | +#define TAIL_WHILE \ |
| 51 | +" whilelo p1."DTYPE", x8, x0 \n" |
| 52 | +#define UPDATE(pg, x,y,out) \ |
| 53 | +" ld1"WIDTH" { z2."DTYPE" }, "pg"/z, ["x", x8, lsl #"SHIFT"] \n" \ |
| 54 | +" ld1"WIDTH" { z3."DTYPE" }, "pg"/z, ["y", x8, lsl #"SHIFT"] \n" \ |
| 55 | +" fmla "out"."DTYPE", "pg"/m, z2."DTYPE", z3."DTYPE" \n" |
| 56 | +#define SUM_VECTOR(v) \ |
| 57 | +" faddv "DTYPE""v", p0, z"v"."DTYPE" \n" |
| 58 | +#define RET \ |
| 59 | +" fadd %"DTYPE"[RET_], "DTYPE"1, "DTYPE"0 \n" |
49 | 60 |
|
50 |
| - BLASLONG sve_width = SVE_WIDTH; |
| 61 | +#define DOT_KERNEL \ |
| 62 | + COUNT \ |
| 63 | +" mov z1.d, #0 \n" \ |
| 64 | +" mov z0.d, #0 \n" \ |
| 65 | +" mov x8, #0 \n" \ |
| 66 | +" movi d1, #0x0 \n" \ |
| 67 | + SETUP_TRUE \ |
| 68 | +" neg x10, x9, lsl #1 \n" \ |
| 69 | +" ands x11, x10, x0 \n" \ |
| 70 | +" b.eq .Lskip_2x \n" \ |
| 71 | + OFFSET_INPUTS \ |
| 72 | +".Lvector_2x: \n" \ |
| 73 | + UPDATE("p0", "%[X_]", "%[Y_]", "z1") \ |
| 74 | + UPDATE("p0", "x12", "x13", "z0") \ |
| 75 | +" sub x8, x8, x10 \n" \ |
| 76 | +" cmp x8, x11 \n" \ |
| 77 | +" b.lo .Lvector_2x \n" \ |
| 78 | + SUM_VECTOR("1") \ |
| 79 | +".Lskip_2x: \n" \ |
| 80 | +" neg x10, x9 \n" \ |
| 81 | +" and x10, x10, x0 \n" \ |
| 82 | +" cmp x8, x10 \n" \ |
| 83 | +" b.hs .Ltail \n" \ |
| 84 | +".Lvector_1x: \n" \ |
| 85 | + UPDATE("p0", "%[X_]", "%[Y_]", "z0") \ |
| 86 | +" add x8, x8, x9 \n" \ |
| 87 | +" cmp x8, x10 \n" \ |
| 88 | +" b.lo .Lvector_1x \n" \ |
| 89 | +".Ltail: \n" \ |
| 90 | +" cmp x10, x0 \n" \ |
| 91 | +" b.eq .Lend \n" \ |
| 92 | + TAIL_WHILE \ |
| 93 | + UPDATE("p1", "%[X_]", "%[Y_]", "z0") \ |
| 94 | +".Lend: \n" \ |
| 95 | + SUM_VECTOR("0") \ |
| 96 | + RET |
51 | 97 |
|
52 |
| - for (BLASLONG i = 0; i < n; i += sve_width * 2) { |
53 |
| - svbool_t pg_a = SVE_WHILELT((uint64_t)i, (uint64_t)n); |
54 |
| - svbool_t pg_b = SVE_WHILELT((uint64_t)(i + sve_width), (uint64_t)n); |
| 98 | +static |
| 99 | +FLOAT |
| 100 | +dot_kernel_sve(BLASLONG n, FLOAT* x, FLOAT* y) |
| 101 | +{ |
| 102 | + FLOAT ret; |
55 | 103 |
|
56 |
| - SVE_TYPE x_vec_a = svld1(pg_a, &x[i]); |
57 |
| - SVE_TYPE y_vec_a = svld1(pg_a, &y[i]); |
58 |
| - SVE_TYPE x_vec_b = svld1(pg_b, &x[i + sve_width]); |
59 |
| - SVE_TYPE y_vec_b = svld1(pg_b, &y[i + sve_width]); |
| 104 | + asm(DOT_KERNEL |
| 105 | + : |
| 106 | + [RET_] "=&w" (ret) |
| 107 | + : |
| 108 | + [N_] "r" (n), |
| 109 | + [X_] "r" (x), |
| 110 | + [Y_] "r" (y) |
| 111 | + :); |
60 | 112 |
|
61 |
| - acc_a = svmla_m(pg_a, acc_a, x_vec_a, y_vec_a); |
62 |
| - acc_b = svmla_m(pg_b, acc_b, x_vec_b, y_vec_b); |
63 |
| - } |
64 |
| - |
65 |
| - return svaddv(SVE_ALL, acc_a) + svaddv(SVE_ALL, acc_b); |
| 113 | + return ret; |
66 | 114 | }
|
0 commit comments