From a83d9931839c9d926364f52befcc8a48ecbc1885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Apr 2024 11:39:16 +0200 Subject: [PATCH 1/7] CUDA: refactor host code, dyn. par. blocks --- ggml-cuda.cu | 1 + ggml-cuda/common.cuh | 6 + ggml-cuda/fattn.cu | 542 +++++++++++++++++++------------------------ 3 files changed, 248 insertions(+), 301 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 11adbabd655e1..2cf6c8d98bd89 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -141,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b0149b7be22b3..989780dbce88c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -390,6 +390,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL +#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -403,6 +408,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 91ef5551e025a..5f1345a7fe94f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask; - - if (parallel_blocks == 1) { - Q_f2 += blockIdx.x*nb01/sizeof(float2); - maskh += blockIdx.x*ne11; - } + const half * maskh = (const half *) mask + ne11*ic; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); @@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; @@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16( return; } + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); if (parallel_blocks == 1) { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; - } else { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; - if (tid == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); - } + if (parallel_blocks == 1 || tid != 0) { + return; } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; @@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask; - - if (parallel_blocks == 1) { - Q_f += blockIdx.x * ncols*nb01/sizeof(float); - mask2 += blockIdx.x * ncols*ne11/2; - } + const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - if (parallel_blocks == 1) { - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } else { - KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } - if (parallel_blocks == 1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; - } + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; } - } else { + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + half dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; } + + half2 dst_meta_val = KQ_max[j0/nwarps]; + reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // FP16_MMA_AVAILABLE } template // D == head size @@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } constexpr int get_max_power_of_2(int x) { @@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, nullptr, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - const cudaStream_t main_stream = ctx.stream(); - - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); - - if (Q->ne[1] == 1) { + if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { constexpr int parallel_blocks = 4; - - ggml_cuda_pool_alloc dst_tmp(ctx.pool()); - ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - - const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const int shmem = 0; - - // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: - constexpr int nwarps_tc = 4; - constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - - const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); - const dim3 block_dim_combine(Q->ne[0], 1, 1); - const int shmem_combine = 0; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - switch (Q->ne[0]) { case 64: - flash_attn_vec_ext_f16<64, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<64, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 80: - flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<80, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 96: - flash_attn_vec_ext_f16<96, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<96, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 112: - flash_attn_vec_ext_f16<112, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<112, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - flash_attn_vec_ext_f16<128, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<128, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - flash_attn_vec_ext_f16<256, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<256, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); return; } - int cols_per_block; - if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } - constexpr int nwarps = 4; - const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const size_t shmem = 0; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - switch (Q->ne[0]) { - case 64: switch (cols_per_block) { - FATTN_SWITCH_CASE(64, 8, nwarps); - FATTN_SWITCH_CASE(64, 16, nwarps); - FATTN_SWITCH_CASE(64, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 80: switch (cols_per_block) { - // FATTN_SWITCH_CASE(80, 8, nwarps); - FATTN_SWITCH_CASE(80, 16, nwarps); - FATTN_SWITCH_CASE(80, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 96: switch (cols_per_block) { - FATTN_SWITCH_CASE(96, 8, nwarps); - FATTN_SWITCH_CASE(96, 16, nwarps); - FATTN_SWITCH_CASE(96, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 112: switch (cols_per_block) { - // FATTN_SWITCH_CASE(112, 8, nwarps); - FATTN_SWITCH_CASE(112, 16, nwarps); - FATTN_SWITCH_CASE(112, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 128: switch (cols_per_block) { - FATTN_SWITCH_CASE(128, 8, nwarps); - FATTN_SWITCH_CASE(128, 16, nwarps); - FATTN_SWITCH_CASE(128, 32, nwarps); default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; - case 256: switch (cols_per_block) { - FATTN_SWITCH_CASE(256, 8, nwarps); - FATTN_SWITCH_CASE(256, 16, nwarps); - FATTN_SWITCH_CASE(256, 32, nwarps); + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); + return; } From 359d0f565e63b1d1527e18919d68e23e83a9c1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 13 Apr 2024 22:05:43 +0200 Subject: [PATCH 2/7] fix flash_attn_vec_f16 race condition --- ggml-cuda/fattn.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 5f1345a7fe94f..36479b2170979 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16( VKQ += V_k*KQ2[k0/2]; } } + + __syncthreads(); } if (tid >= D) { @@ -547,7 +549,7 @@ template void launch_fattn_vec_f16( dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } - constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -561,7 +563,7 @@ template void launch_fattn_vec_f16( (const char *) K->data, (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -572,7 +574,7 @@ template void launch_fattn_vec_f16( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { + if (parallel_blocks == 1) { return; } From 049533d99fdc4c84ba38c0fb3468ac1ced017052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 15 Apr 2024 16:05:07 +0200 Subject: [PATCH 3/7] flush softmax exp below threshold to 0 --- ggml-cuda/fattn.cu | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 36479b2170979..f6289822e0ea0 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,8 +3,9 @@ #include -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, val); + KQ2[j*(kqs_padded/2) + k] = val; } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; + KQ_max_scale[j0/nwarps] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; KQ_max[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - val = h2exp(val - KQ_max[j0/nwarps]); + const half2 diff = val - KQ_max[j0/nwarps]; + val = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &val) &= ftz_mask; KQ_rowsum_add += val; KQ2[j*(kqs_padded/2) + k] = val; } @@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + const half diff = __low2half(meta[l]) - kqmax; + float KQ_max_scale = hexp(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * __high2float(meta[l]); From aef96ff40abe7fc6040f54078f0a845378bf3b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 15:58:21 +0200 Subject: [PATCH 4/7] store temp KQ in registers --- ggml-cuda/fattn.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f6289822e0ea0..b889cdb3b9b01 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, val); - KQ2[j*(kqs_padded/2) + k] = val; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; @@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - const half2 diff = val - KQ_max[j0/nwarps]; - val = h2exp(diff); + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &val) &= ftz_mask; - KQ_rowsum_add += val; - KQ2[j*(kqs_padded/2) + k] = val; + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); From a9d6591652a239ec8519dba78505592d7665d282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 16:22:29 +0200 Subject: [PATCH 5/7] Calculate KQ as FP32 if KQV has GGML_PREC_F32 --- ggml-cuda/fattn.cu | 286 +++++++++++++++++++++++++++++++++------------ 1 file changed, 213 insertions(+), 73 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index b889cdb3b9b01..dda344531335c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn.cuh" +#include #include #define FATTN_KQ_STRIDE 256 @@ -185,7 +186,8 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -229,7 +231,8 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -238,12 +241,14 @@ static __global__ void flash_attn_ext_f16( // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: constexpr int D_padded = D + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -251,14 +256,29 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_KQ = ncols*kqs_padded*kqar; constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; - half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -307,7 +327,7 @@ static __global__ void flash_attn_ext_f16( // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c KQ_c[ncols/frag_n]; + frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); @@ -323,7 +343,7 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -335,45 +355,90 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - } + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } - half2 KQ_max_new = KQ_max[j0/nwarps]; + float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; - KQ_max_scale[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; - KQ_max[j0/nwarps] = KQ_max_new; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } } __syncthreads(); @@ -386,12 +451,12 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; nvcuda::wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*kqs_padded + k, - kqs_padded); + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); } } - frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll @@ -431,6 +496,14 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -443,7 +516,7 @@ static __global__ void flash_attn_ext_f16( for (int l = 0; l < VKQ_ratio; ++l) { VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -458,14 +531,20 @@ static __global__ void flash_attn_ext_f16( } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + #pragma unroll for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D && i >= D) { break; } - half dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*D_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } @@ -476,7 +555,12 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val = KQ_max[j0/nwarps]; + half2 dst_meta_val; + if (std::is_same::value) { + reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val = KQ_max_h2[j0/nwarps]; + } reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } @@ -606,7 +690,7 @@ template void launch_fattn_vec_f16( CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16_impl( +template void launch_fattn_f16_impl( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { @@ -626,7 +710,7 @@ template void launc float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - flash_attn_ext_f16 + flash_attn_ext_f16 <<>> ( (const char *) Q->data, (const char *) K->data, @@ -657,21 +741,21 @@ template void launc CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16( +template void launch_fattn_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream ) { const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; if (4*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } if (2*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -696,15 +780,73 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { constexpr int parallel_blocks = 4; switch (Q->ne[0]) { case 64: launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; - case 96: - launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; case 128: launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; @@ -718,23 +860,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -748,22 +888,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -776,22 +916,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); From 4e4d58ab6ac002942cbfb9a2879368cce0e7dd5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 16:29:28 +0200 Subject: [PATCH 6/7] Add __hgt2_mask implementation for CUDA 11 --- ggml-cuda/common.cuh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 989780dbce88c..ac6de643d668e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -306,6 +306,13 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { + const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 From 44ca5764d621d8693f8f8c01b9b920d7620c9076 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 17:31:03 +0200 Subject: [PATCH 7/7] fix KQ FP32 precision fpr parallel_blocks > 1 --- ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dda344531335c..4cf2907e8d10c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16( if (parallel_blocks == 1 || tid != 0) { return; } - dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val; + float2 dst_meta_val; if (std::is_same::value) { - reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + dst_meta_val.x = KQ_max_f[j0/nwarps]; } else { - dst_meta_val = KQ_max_h2[j0/nwarps]; + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } - reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta_val.y = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else @@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const half2 * __restrict__ VKQ_meta, + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { #if FP16_AVAILABLE VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; @@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results( const int tid = threadIdx.x; __builtin_assume(tid < D); - __shared__ half2 meta[parallel_blocks]; - if (tid < parallel_blocks) { - meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; } __syncthreads(); - half kqmax = __low2half(meta[0]); + float kqmax = meta[0].x; #pragma unroll for (int l = 1; l < parallel_blocks; ++l) { - kqmax = __hmax(kqmax, __low2half(meta[l])); + kqmax = max(kqmax, meta[l].x); } float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - const half diff = __low2half(meta[l]) - kqmax; - float KQ_max_scale = hexp(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; - VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + VKQ_denominator += KQ_max_scale * meta[l].y; } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; @@ -643,8 +643,8 @@ template void launch_fattn_vec_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -694,8 +694,8 @@ template dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));