
unroll 2 loops, int64_t -> int, 309 µs #4

Merged · 1 commit · Feb 3, 2024
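This PR unrolls the two Q-loading loops in `flash_attn_ext_f16` and switches the loop counters from `int64_t` to `int`, bringing the kernel to 309 µs per the title. A loop whose induction variable starts at a runtime offset such as `warp_id` or `lane_id` has no compile-time trip count, so `#pragma unroll` cannot take effect; rewriting it as a 0-based `int` counter with an explicit bound check fixes that, and 32-bit index arithmetic is also cheaper than 64-bit on GPU ALUs. A minimal standalone sketch of the transformation (hypothetical kernel name and constants, not taken from the diff below):

```cuda
#include <cuda_runtime.h>

// Hypothetical stand-ins for the kernel's compile-time parameters.
constexpr int NW = 32; // threads per warp
constexpr int D2 = 64; // halved head dimension

__global__ void unroll_pattern_demo(const float * src, float * dst) {
    const int lane_id = threadIdx.x % NW;

    // Before (not unrollable): the first iteration index depends on
    // the runtime value lane_id and the counter is 64-bit:
    //     for (int64_t i = lane_id; i < D2; i += NW) { ... }

    // After: a 0-based counter with a constant trip count (D2/NW
    // steps), so #pragma unroll can take effect; the explicit guard
    // handles the tail when D2 is not a multiple of NW.
#pragma unroll
    for (int i0 = 0; i0 < D2; i0 += NW) {
        const int i = i0 + lane_id;
        if (i >= D2) {
            break;
        }
        dst[i] = src[i];
    }
}
```

The same rewrite is applied below to both the row loop over `Q` and the column loop over `D2`; the remaining `int64_t` counters, whose bounds are already compile-time constants, are changed to plain `int` as well.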
ggml-cuda.cu — 82 changes: 47 additions & 35 deletions
@@ -6467,10 +6467,22 @@ static __global__ void flash_attn_ext_f16(
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
-    for (int64_t j = warp_id; j < Q; j += num_warps) {
+#pragma unroll
+    for (int j0 = 0; j0 < Q; j0 += num_warps) {
+        const int j = j0 + warp_id;
+        if (j >= Q) {
+            break;
+        }
+
         const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
 
-        for (int64_t i = lane_id; i < D2; i += NW) {
+#pragma unroll
+        for (int i0 = 0; i0 < D2; i0 += NW) {
+            const int i = i0 + lane_id;
+            if (i >= D2) {
+                break;
+            }
+
             if (iq1 + j < ne01) {
                 sq2[j*T2 + i] = __float22half2_rn(q2[i]);
             } else {
@@ -6482,15 +6494,15 @@ static __global__ void flash_attn_ext_f16(
     nvcuda::wmma::fill_fragment(zr, 0.0);
 
     // zero out lo
-    for (int64_t j = 0; j < Q16; ++j) {
-        for (int64_t i = 0; i < D16; ++i) {
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
         }
     }
 
     // zero out shared memory SH
-    for (int64_t j = 0; j < Q; ++j) {
-        for (int64_t i = lane_id; i < SH; i += NW) {
+    for (int j = 0; j < Q; ++j) {
+        for (int i = lane_id; i < SH; i += NW) {
             ss[j*T + i] = 0.0;
         }
     }
@@ -6531,8 +6543,8 @@ static __global__ void flash_attn_ext_f16(
 
         // load the queries from shared memory into local memory
         half16x16_a mq[Q16][D16];
-        for (int64_t j = 0; j < Q16; ++j) {
-            for (int64_t i = 0; i < D16; ++i) {
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
             }
         }
@@ -6549,28 +6561,28 @@ static __global__ void flash_attn_ext_f16(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
+        for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
             // Q*K^T
             {
                 for (int cc = 0; cc < C/16; ++cc) {
                     half16x16_acc mqk[Q16];
-                    for (int64_t j = 0; j < Q16; ++j) {
+                    for (int j = 0; j < Q16; ++j) {
                         nvcuda::wmma::fill_fragment(mqk[j], 0);
                     }
 
                     const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-                    for (int64_t i = 0; i < D16; ++i) {
+                    for (int i = 0; i < D16; ++i) {
                         half16x16_bT mk; // transposed key
                         nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
 
-                        for (int64_t j = 0; j < Q16; ++j) {
+                        for (int j = 0; j < Q16; ++j) {
                             nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                         }
                     }
 
                     // mqk = mqk*scale + mask
-                    for (int64_t j = 0; j < Q16; ++j) {
+                    for (int j = 0; j < Q16; ++j) {
                         half16x16_a mqka;
                         half16x16_acc mm;
                         if(mp) {
@@ -6592,8 +6604,8 @@ static __global__ void flash_attn_ext_f16(
 
             // online softmax
             if (C == 32) {
-                for (int64_t j = 0; j < Q; ++j) {
-                    const int64_t p = lane_id;
+                for (int j = 0; j < Q; ++j) {
+                    const int p = lane_id;
 
                     const half m = M[j];
                     const half s = ss[j*T + p];
@@ -6615,10 +6627,10 @@ static __global__ void flash_attn_ext_f16(
                     ss[j*T + p] = vs;
                 }
             } else {
-                for (int64_t j = 0; j < Q; ++j) {
+                for (int j = 0; j < Q; ++j) {
                     const half m = M[j];
 
-                    for (int64_t p = lane_id; p < C; p += NW) {
+                    for (int p = lane_id; p < C; p += NW) {
                         const half s = ss[j*T + p];
 
                         smax = __hmax(smax, s);
@@ -6638,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
                     // local sum
                     half ls = 0.0f;
 
-                    for (int64_t p = lane_id; p < C; p += NW) {
+                    for (int p = lane_id; p < C; p += NW) {
                         const half s = ss[j*T + p];
 
                         const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6659,13 +6671,13 @@ static __global__ void flash_attn_ext_f16(
             }
 
             // O = diag(ms)*O
-            for (int64_t j = 0; j < Q16; ++j) {
+            for (int j = 0; j < Q16; ++j) {
                 half16x16_a mm;
                 half16x16_b lob;
 
                 nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     // convert accumulator to matrix_b
                     nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
                     nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
@@ -6684,17 +6696,17 @@ static __global__ void flash_attn_ext_f16(
                     const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
                     half16x16_b mk[D16];
-                    for (int64_t i = 0; i < D16; ++i) {
+                    for (int i = 0; i < D16; ++i) {
                         nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
                     }
 
                     half16x16_a mv[Q16];
-                    for (int64_t j = 0; j < Q16; ++j) {
+                    for (int j = 0; j < Q16; ++j) {
                         nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
                     }
 
-                    for (int64_t j = 0; j < Q16; ++j) {
-                        for (int64_t i = 0; i < D16; ++i) {
+                    for (int j = 0; j < Q16; ++j) {
+                        for (int i = 0; i < D16; ++i) {
                             nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                         }
                     }
@@ -6703,7 +6715,7 @@ static __global__ void flash_attn_ext_f16(
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (int64_t j = 0; j < Q; ++j) {
+        for (int j = 0; j < Q; ++j) {
            if (lane_id == 0) {
                 ss[j*T + 0] = S[j];
                 ss[j*T + 1] = M[j];
@@ -6712,16 +6724,16 @@ static __global__ void flash_attn_ext_f16(
     }
 
     // reduce the warps sequentially
-    for (int64_t sg = 1; sg < num_warps; ++sg) {
+    for (int sg = 1; sg < num_warps; ++sg) {
         half S = __float2half(0.0f);
         half M = __float2half(-INFINITY);
 
         __syncthreads();
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (warp_id == sg) {
-            for (int64_t j = 0; j < Q16; ++j) {
-                for (int64_t i = 0; i < D16; ++i) {
+            for (int j = 0; j < Q16; ++j) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
@@ -6731,7 +6743,7 @@ static __global__ void flash_attn_ext_f16(
 
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (int j = 0; j < Q; ++j) {
                 const half S0 = ss[j*T + 0];
                 const half S1 = ss[j*T + sg*SH + 0];
 
@@ -6755,7 +6767,7 @@ static __global__ void flash_attn_ext_f16(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (int64_t j = 0; j < Q16; ++j) {
+            for (int j = 0; j < Q16; ++j) {
                 half16x16_a ms0;
                 half16x16_a ms1;
                 half16x16_b t;
@@ -6764,7 +6776,7 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
 
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
@@ -6781,19 +6793,19 @@ static __global__ void flash_attn_ext_f16(
 
     // store result to shared memory (reuse sq)
     if (warp_id == 0) {
-        for (int64_t j = 0; j < Q16; ++j) {
-            for (int64_t i = 0; i < D16; ++i) {
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
     }
 
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
-        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];
 
-            for (int64_t i = lane_id; i < D; i += NW) {
+            for (int i = lane_id; i < D; i += NW) {
                 dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
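For context on what the unchanged logic above computes: the `// online softmax` block maintains, per output row, a running maximum `M` and a running sum `S` of `exp(score - M)`, rescaling previously accumulated values whenever the maximum grows, and the sequential warp reduction merges per-warp partials the same way. A minimal scalar sketch of these two update rules (my own illustration, with hypothetical names and `float` in place of the kernel's `half`):

```cuda
#include <cmath>

// Per-row running state: M = max score seen, S = sum of exp(score - M),
// O = unnormalized output accumulator (one float stands in for a row).
struct RowState {
    float M = -INFINITY;
    float S = 0.0f;
    float O = 0.0f;
};

// Absorb one new score s with associated value v: rescale the old sum
// and output by exp(M - M'), so earlier terms stay consistent with the
// new maximum M'. (Masked scores, s == -INFINITY, get weight 0, which
// is what the kernel's __hisinf(s) == -1 check enforces.)
void online_softmax_push(RowState & r, float s, float v) {
    const float Mn = fmaxf(r.M, s);
    const float ms = expf(r.M - Mn); // rescale factor for old terms
    const float vs = expf(s - Mn);   // weight of the new term
    r.S = r.S*ms + vs;
    r.O = r.O*ms + vs*v;
    r.M = Mn;
}

// Merge two partial results, as the sequential warp reduction does:
// bring both onto the common maximum, then add.
RowState online_softmax_merge(const RowState & a, const RowState & b) {
    RowState r;
    r.M = fmaxf(a.M, b.M);
    const float ms0 = expf(a.M - r.M);
    const float ms1 = expf(b.M - r.M);
    r.S = a.S*ms0 + b.S*ms1;
    r.O = a.O*ms0 + b.O*ms1; // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
    return r;
}
```

The kernel's `// final rescale with 1/S` step then divides the accumulated output by `S`, completing the softmax normalization.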