diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ec3837fb88d14..57cfba5857caf 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -104,6 +104,8 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cuda/barrier>
+#include <mma.h>
 
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@@ -122,13 +124,17 @@
 
 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 
-#define CC_PASCAL     600
-#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define CC_VOLTA      700
-#define CC_OFFSET_AMD 1000000
-#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
-#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
-#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
+#define CC_PASCAL       600
+#define MIN_CC_DP4A     610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_VOLTA        700 // minimum compute capability for mma, i.e. tensor cores
+#define CC_TURING       750
+#define CC_AMPERE       800
+#define CC_ADA_LOVELACE 890
+#define CC_HOPPER       900
+#define CC_OFFSET_AMD   1000000
+#define CC_RDNA1        (CC_OFFSET_AMD + 1010)
+#define CC_RDNA2        (CC_OFFSET_AMD + 1030)
+#define CC_RDNA3        (CC_OFFSET_AMD + 1100)
 
 #define GGML_CUDA_MAX_NODES 8192
 
@@ -574,12 +580,14 @@ static std::array<float, GGML_CUDA_MAX_DEVICES> g_default_tensor_split = {};
 
 struct cuda_device_capabilities {
     int     cc;                 // compute capability
-    size_t  smpb;               // max. shared memory per block
+    int     nsm;                // number of streaming multiprocessors
+    size_t  smem;               // shared memory per SM
+    size_t  smempb;             // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, 0, 0, false, 0} };
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
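The three new fields are filled from `cudaDeviceProp` in the `ggml_init_cublas` hunk further down. For reference, a minimal standalone query using the same property names (the `print_caps` helper is hypothetical, not part of the patch):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper: queries the same three properties the patch stores in
// cuda_device_capabilities (see the ggml_init_cublas hunk below).
static void print_caps(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    printf("nsm    = %d\n",  prop.multiProcessorCount);        // streaming multiprocessors
    printf("smem   = %zu\n", prop.sharedMemPerMultiprocessor); // shared memory per SM
    printf("smempb = %zu\n", prop.sharedMemPerBlockOptin);     // max. opt-in shared memory per block
}
```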
@@ -2290,6 +2298,192 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
+template <int block_size, int kx_template>
+static __global__ void convert_q8_0_to_i8(
+    const void * __restrict__ vx, int * __restrict__ y_qs_low, float * __restrict__ y_d, const int kx_par) {
+
+    const int kx = kx_template == 0 ? kx_par : kx_template;
+    const int nint = kx*sizeof(block_q8_0)/(QK8_0*sizeof(int));
+
+    typedef cuda::barrier<cuda::thread_scope_block> cuda_barrier;
+    const cuda::aligned_size_t<4*sizeof(int)> as(4*sizeof(int));
+
+    extern __shared__ float data_q8_0_i8[];
+    cuda_barrier * barrier = (cuda_barrier *) data_q8_0_i8;
+    float * buf_iw = (data_q8_0_i8 + 32);
+    int * vals = (int *) (buf_iw + WARP_SIZE);
+    const block_q8_0 * valsb = (const block_q8_0 *) vals;
+
+    if (threadIdx.x == 0) {
+        init(barrier, block_size);
+    }
+    if (threadIdx.x < WARP_SIZE) {
+        buf_iw[threadIdx.x] = 0.0f;
+    }
+    __syncthreads();
+
+    const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+
+    const int * x = (const int *) vx;
+    char4 * y_qsc = (char4 *) y_qs_low;
+
+#pragma unroll
+    for (int ix0 = 0; ix0 < nint; ix0 += 4*block_size) {
+        const int ix = ix0 + 4*threadIdx.x;
+
+        if (ix >= nint) {
+            break;
+        }
+
+        cuda::memcpy_async(&vals[ix], &x[iy*nint + ix], as, *barrier);
+    }
+
+    barrier->arrive_and_wait();
+
+    float amax = 0.0f;
+#pragma unroll
+    for (int ix0 = 0; ix0 < kx/QK8_0; ix0 += block_size) {
+        const int ix = ix0 + threadIdx.x;
+
+        if (ix >= kx/QK8_0) {
+            break;
+        }
+
+        const float d = __half2float(valsb[ix].d);
+        amax = max(amax, fabsf(d));
+    }
+
+    amax = warp_reduce_max(amax);
+    if (threadIdx.x % WARP_SIZE == 0) {
+        buf_iw[threadIdx.x / WARP_SIZE] = amax;
+    }
+    __syncthreads();
+    amax = buf_iw[threadIdx.x % WARP_SIZE];
+    amax = warp_reduce_max(amax);
+
+#pragma unroll
+    for (int ix0 = 0; ix0 < kx/sizeof(int); ix0 += block_size) {
+        const int ix = ix0 + threadIdx.x;
+
+        if (ix >= kx/sizeof(int)) {
+            break;
+        }
+
+        const block_q8_0 * bxi = valsb + ix/QI8_0;
+        const float scale = __half2float(bxi->d) / amax;
+        const int xi = get_int_from_int8(bxi->qs, ix % QI8_0);
+        const int8_t * xi8 = (const int8_t *) &xi;
+
+        int result32[4];
+#pragma unroll
+        for (int l = 0; l < 4; ++l) {
+            result32[l] = roundf(xi8[l] * scale);
+        }
+
+        y_qsc[iy*kx/sizeof(int) + ix] = make_char4(result32[0], result32[1], result32[2], result32[3]);
+    }
+
+    if (threadIdx.x > 0) {
+        return;
+    }
+
+    y_d[iy] = amax;
+}
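The kernel above folds the per-block q8_0 scales of a row into a single row scale (the largest block scale of the row) and re-rounds the quants accordingly. A plain host-side sketch of the same math, assuming a simplified block struct with a float scale (the real `block_q8_0` stores a half):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

struct block_q8_0_ref { float d; int8_t qs[32]; }; // real struct: half d, QK8_0 == 32

// One row of q8_0 blocks -> one int8 row sharing a single scale (sketch).
void q8_0_row_to_i8(const std::vector<block_q8_0_ref> & row,
                    std::vector<int8_t> & qs_out, float & d_out) {
    float amax = 0.0f;                       // largest block scale in the row
    for (const auto & b : row) amax = std::fmax(amax, std::fabs(b.d));

    qs_out.clear();
    for (const auto & b : row) {
        const float scale = b.d / amax;      // <= 1 in magnitude
        for (int l = 0; l < 32; ++l) {
            // q*d == (q*d/amax)*amax, so dequantization becomes q_new*amax
            qs_out.push_back((int8_t) std::roundf(b.qs[l] * scale));
        }
    }
    d_out = amax;                            // matches y_d[iy] = amax in the kernel
}
```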
+
+#define I8_MAX_FRAG_SCALE 256.0f
+
+template <bool frag_scales, int nrows_smem>
+static __global__ void convert_float_to_i8(
+    const float * __restrict__ x, int * __restrict__ y_qs_low, int * __restrict__ y_bs, float * __restrict__ y_d, const int kx) {
+
+    extern __shared__ float data_convert_float_to_i8[];
+    float * buf_iw = data_convert_float_to_i8;
+    half * valsh = (half *) (buf_iw + WARP_SIZE);
+
+    const int iy0 = 8*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    int8_t * qs_low = (int8_t *) y_qs_low;
+
+    float amax_row[8] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int iy = iy0 + j;
+
+        float amax = 0.0f;
+        for (int ix0 = 0; ix0 < kx; ix0 += blockDim.x) {
+            const int ix = ix0 + threadIdx.x;
+
+            if (ix >= kx) {
+                break;
+            }
+
+            const float xi = x[iy*kx + ix];
+            amax = max(amax, fabsf(xi));
+
+            if (j < nrows_smem) {
+                valsh[j*kx + ix] = xi;
+            }
+        }
+
+        amax = warp_reduce_max(amax);
+        if (threadIdx.x % WARP_SIZE == 0) {
+            buf_iw[threadIdx.x / WARP_SIZE] = amax;
+        }
+        __syncthreads();
+        amax = buf_iw[threadIdx.x % WARP_SIZE];
+        amax = warp_reduce_max(amax);
+
+        amax_row[j] = amax;
+
+        if (threadIdx.x == 0) {
+            y_d[iy] = amax_row[j] / 127;
+        }
+        __syncthreads();
+    }
+
+    for (int ix0 = 0; ix0 < kx; ix0 += blockDim.x) {
+        const int ix = ix0 + threadIdx.x;
+
+        float rmax = 0.0f;
+        float valsi[8];
+
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+            const int iy = iy0 + j;
+
+            const float xi = ix >= kx ? 0.0f : (j < nrows_smem ? __half2float(valsh[j*kx + ix]) : x[iy*kx + ix]);
+            rmax = max(rmax, fabsf(xi) / amax_row[j]);
+            valsi[j] = xi;
+        }
+
+#pragma unroll
+        for (int mask = 8; mask > 0; mask >>= 1) {
+            rmax = max(rmax, __shfl_xor_sync(0xFFFFFFFF, rmax, mask, 32));
+        }
+        const int bs = roundf(rmax * I8_MAX_FRAG_SCALE);
+
+        if (ix >= kx) {
+            break;
+        }
+
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+            const int iy = iy0 + j;
+
+            const float xi = valsi[j];
+            const int q = rmax == 0.0f ? 0 : roundf(xi * 127 / ((frag_scales ? rmax : 1.0f) * amax_row[j]));
+
+            qs_low[iy*kx + ix] = q;
+        }
+
+        if (frag_scales && ix % 16 == 0) {
+            y_bs[(iy0/8)*(kx/16) + ix/16] = bs;
+        }
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
         const void * src0, const int32_t * src1, dst_t * dst,
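`convert_float_to_i8` quantizes 8 rows at a time: one scale per row (`amax/127`) plus, when `frag_scales` is set, an extra integer scale per 16-column fragment so that values sharing a fragment with a row outlier keep precision. A host-side sketch of the `frag_scales == true` path, with names of my own choosing:

```cpp
#include <cmath>
#include <cstdint>

// Host reference for convert_float_to_i8 with frag_scales == true (sketch).
// x points to 8 rows of kx floats; kx is assumed to be a multiple of 16.
void float_rows_to_i8(const float * x, int kx, int8_t * qs, int * bs, float * d) {
    float amax[8];
    for (int j = 0; j < 8; ++j) {
        amax[j] = 0.0f;
        for (int ix = 0; ix < kx; ++ix) amax[j] = std::fmax(amax[j], std::fabs(x[j*kx + ix]));
        d[j] = amax[j] / 127;                               // per-row scale
    }
    for (int ix0 = 0; ix0 < kx; ix0 += 16) {                // one 16-wide fragment
        float rmax = 0.0f;                                  // fragment max, relative to row amax
        for (int j = 0; j < 8; ++j)
            for (int ix = ix0; ix < ix0 + 16; ++ix)
                rmax = std::fmax(rmax, std::fabs(x[j*kx + ix]) / amax[j]);
        bs[ix0/16] = (int) std::roundf(rmax * 256.0f);      // I8_MAX_FRAG_SCALE
        for (int j = 0; j < 8; ++j)
            for (int ix = ix0; ix < ix0 + 16; ++ix)
                qs[j*kx + ix] = rmax == 0.0f ? 0
                    : (int8_t) std::roundf(x[j*kx + ix] * 127 / (rmax * amax[j]));
    }
}
```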
@@ -5116,6 +5310,248 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define MMI8_PADDING 4
+#define MMI8_TILE_STRIDE (2*WARP_SIZE + MMI8_PADDING)
+#define MMI8_COPY_SIZE 2
+#define MMI8_N_BARRIERS 8
+static_assert(MMI8_N_BARRIERS < WARP_SIZE, "Max. 32 barrier support implemented.");
+
+#define MMI8_X_AMPERE 64
+#define MMI8_Y_AMPERE 144
+#define MMI8_NWARPS_AMPERE 4
+
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 32, 8, 16, signed char, nvcuda::wmma::row_major> frag_thin_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 32, 8, 16, signed char, nvcuda::wmma::col_major> frag_thin_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 32, 8, 16, int> frag_thin_c;
+
+typedef cuda::barrier<cuda::thread_scope_block> cuda_barrier;
+
+template <int mmi8_x, int mmi8_y, bool need_check_x, bool need_check_y, int prefetch>
+static __device__ __forceinline__ void load_tiles_i8(
+    const int * __restrict__ x_qs_low, const int * __restrict__ y_qs_low, const int * __restrict__ y_bs,
+    int * __restrict__ tile_x_qs, int * __restrict__ tile_y_qs,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int k0,
+    cuda_barrier * barriers, const int ib_next) {
+
+    constexpr int nwarps = mmi8_y/32;
+    const cuda::aligned_size_t<MMI8_COPY_SIZE*sizeof(int)> as(MMI8_COPY_SIZE*sizeof(int));
+    const cuda::aligned_size_t<sizeof(int)> as4(sizeof(int));
+
+    if (k0 % (4*WARP_SIZE) == 0) {
+        const int ib_next_4 = (ib_next + 3) % MMI8_N_BARRIERS;
+#pragma unroll
+        for (int j0 = 0; j0 < mmi8_x; j0 += 8*nwarps) {
+            if (j0 + 8*nwarps > mmi8_x && j0 + 8*threadIdx.y > mmi8_x-8) {
+                break;
+            }
+
+            const int j_bs = (blockIdx.y*(mmi8_x/8) + j0/8 + threadIdx.y);
+            const int k_bs = k0/4 + prefetch*WARP_SIZE + threadIdx.x;
+
+            const int j_tile = (ib_next_4/4)*mmi8_x + j0 + 8*threadIdx.y + threadIdx.x/4;
+            const int k_tile = 2*WARP_SIZE + threadIdx.x % 4;
+
+            cuda::memcpy_async(&tile_x_qs[j_tile*MMI8_TILE_STRIDE + k_tile],
+                &y_bs[j_bs*(nrows_y/16) + k_bs], as4, barriers[ib_next_4]);
+        }
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmi8_y; i0 += nwarps*MMI8_COPY_SIZE) {
+        const int x = MMI8_COPY_SIZE*(threadIdx.x % (WARP_SIZE/MMI8_COPY_SIZE));
+        const int y = MMI8_COPY_SIZE*threadIdx.y + threadIdx.x/(WARP_SIZE/MMI8_COPY_SIZE);
+
+        const int i_tile = i0 + y;
+        int i_qs = blockIdx.x*mmi8_y + i_tile;
+        if (need_check_x) {
+            i_qs = min(i_qs, nrows_x-1);
+        }
+
+        const int index_tile = i_tile * MMI8_TILE_STRIDE + (ib_next % 2)*WARP_SIZE + x;
+        const int index_qs   = i_qs   * (ncols_x/sizeof(int)) + k0 + prefetch*WARP_SIZE + x;
+        cuda::memcpy_async(&tile_x_qs[index_tile], &x_qs_low[index_qs], as, barriers[ib_next]);
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmi8_x; j0 += nwarps*MMI8_COPY_SIZE) {
+        const int x = MMI8_COPY_SIZE*(threadIdx.x % (WARP_SIZE/MMI8_COPY_SIZE));
+        const int y = MMI8_COPY_SIZE*threadIdx.y + threadIdx.x/(WARP_SIZE/MMI8_COPY_SIZE);
+
+        int j_tile = j0 + y;
+        if (j0 + nwarps*MMI8_COPY_SIZE > mmi8_x) {
+            j_tile = min(j_tile, mmi8_x-1);
+        }
+
+        int j_qs = blockIdx.y*mmi8_x + j_tile;
+        if (need_check_y) {
+            j_qs = min(j_qs, ncols_y-1);
+        }
+
+        const int index_tile = j_tile * MMI8_TILE_STRIDE + (ib_next % 2)*WARP_SIZE + x;
+        const int index_qs   = j_qs   * (nrows_y/sizeof(int)) + k0 + prefetch*WARP_SIZE + x;
+        cuda::memcpy_async(&tile_y_qs[index_tile], &y_qs_low[index_qs], as, barriers[ib_next]);
+    }
+}
+
+template <int mmi8_x>
+static __device__ __forceinline__ void vec_dot_i8(
+    const int * __restrict__ tile_x_qs, const int * __restrict__ tile_y_qs, frag_thin_c * fc,
+    const int k0, const int ib_current) {
+
+#pragma unroll
+    for (int k = 0; k < 32; k += 16/sizeof(int)) {
+        frag_thin_a fa;
+
+        const int ibs = ((k0/WARP_SIZE) % 4)*8 + k/4;
+
+        nvcuda::wmma::load_matrix_sync(
+            fa, (int8_t *) &tile_x_qs[threadIdx.y*(32*MMI8_TILE_STRIDE) + (ib_current % 2)*WARP_SIZE + k],
+            MMI8_TILE_STRIDE*sizeof(int));
+
+#pragma unroll
+        for (int j = 0; j < mmi8_x; j += 8) {
+            frag_thin_b fb;
+            frag_thin_c fc_tmp;
+
+            const int bs = tile_x_qs[((ib_current/4)*mmi8_x + j + ibs/4)*MMI8_TILE_STRIDE + 2*WARP_SIZE + ibs % 4];
+            nvcuda::wmma::load_matrix_sync(
+                fb, (int8_t *) &tile_y_qs[j*MMI8_TILE_STRIDE + (ib_current % 2)*WARP_SIZE + k],
+                MMI8_TILE_STRIDE*sizeof(int));
+
+            nvcuda::wmma::fill_fragment(fc_tmp, 0);
+            nvcuda::wmma::mma_sync(fc_tmp, fa, fb, fc_tmp);
+#pragma unroll
+            for (int l = 0; l < 32*8/WARP_SIZE; ++l) {
+                fc[j/8].x[l] += bs * fc_tmp.x[l];
+            }
+        }
+    }
+}
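The fragment shapes in the `frag_thin_*` typedefs were reconstructed from usage (32x8 accumulators, k = 16, int8 operands; 32x8x16 is one of the wmma shapes supported for `signed char` on Turing and newer). A minimal standalone tile multiplication with the same shape, to be launched by a single warp:

```cpp
#include <mma.h>
using namespace nvcuda;

// Minimal mma tile: C(32x8, int32) += A(32x16, int8) * B(16x8, int8).
// Launch with exactly one warp (32 threads); the leading-dimension arguments
// of load/store_matrix_sync are in elements, as in vec_dot_i8 above.
__global__ void mma_32x8x16(const signed char * A, const signed char * B, int * C) {
    wmma::fragment<wmma::matrix_a, 32, 8, 16, signed char, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b, 32, 8, 16, signed char, wmma::col_major> fb;
    wmma::fragment<wmma::accumulator, 32, 8, 16, int> fc;

    wmma::fill_fragment(fc, 0);
    wmma::load_matrix_sync(fa, A, 16);  // 32 rows, row major, leading dimension 16
    wmma::load_matrix_sync(fb, B, 16);  // 8 cols, col major, leading dimension 16
    wmma::mma_sync(fc, fa, fb, fc);
    wmma::store_matrix_sync(C, fc, 8, wmma::mem_row_major);
}
```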
+
+// Set launch bounds based on available SRAM:
+#define MMI8_LAUNCH_BOUNDS(kiB) __launch_bounds__(WARP_SIZE*mmi8_y/32, (kiB)*1024 / (512 + sizeof(int)*(mmi8_x + mmi8_y)*MMI8_TILE_STRIDE))
+
+template <int mmi8_x, int mmi8_y, bool need_check_x, bool need_check_y>
+#if __CUDA_ARCH__ >= CC_HOPPER
+MMI8_LAUNCH_BOUNDS(228)
+#elif __CUDA_ARCH__ >= CC_ADA_LOVELACE
+MMI8_LAUNCH_BOUNDS(100)
+#elif __CUDA_ARCH__ >= 870 // Jetson
+MMI8_LAUNCH_BOUNDS(164)
+#elif __CUDA_ARCH__ >= 860 // Ampere consumer
+MMI8_LAUNCH_BOUNDS(100)
+#elif __CUDA_ARCH__ >= CC_AMPERE // Ampere A100
+MMI8_LAUNCH_BOUNDS(164)
+#elif __CUDA_ARCH__ >= CC_TURING
+MMI8_LAUNCH_BOUNDS(64)
+#else // Volta, Jetson
+MMI8_LAUNCH_BOUNDS(96)
+#endif
+static __global__ void mul_mat_i8(
+    const int * __restrict__ x_qs_low, const float * x_d, const int * __restrict__ y_qs_low, const int * __restrict__ y_bs,
+    const float * y_d, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y,
+    const int nrows_y, const int nrows_dst) {
+
+// #if __CUDA_ARCH__ >= CC_VOLTA && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+    constexpr int nwarps = mmi8_y/32;
+
+    extern __shared__ char data_mmi8[];
+    cuda_barrier * barriers = (cuda_barrier *) data_mmi8;
+    int * tile_x_qs = (int *) (data_mmi8 + 512);
+    int * tile_y_qs = tile_x_qs + mmi8_y*MMI8_TILE_STRIDE;
+
+    if (threadIdx.x < MMI8_N_BARRIERS && threadIdx.y == 0) {
+        init(&barriers[threadIdx.x], nwarps*WARP_SIZE);
+    }
+    __syncthreads();
+
+    const int & ncols_dst = ncols_y;
+
+    frag_thin_c fc[mmi8_x/8];
+
+    {
+        constexpr int k0       = 0;
+        constexpr int ib_next  = 0;
+        constexpr int prefetch = 0;
+
+        load_tiles_i8<mmi8_x, mmi8_y, need_check_x, need_check_y, prefetch>(
+            x_qs_low, y_qs_low, y_bs, tile_x_qs, tile_y_qs,
+            ncols_x, nrows_x, nrows_y, ncols_y, k0, barriers, ib_next);
+    }
+
+#pragma unroll
+    for (int j = 0; j < mmi8_x; j += 8) {
+        nvcuda::wmma::fill_fragment(fc[j/8], 0);
+    }
+
+    for (int k0 = 0; k0 < ncols_x/sizeof(int) - WARP_SIZE; k0 += WARP_SIZE) {
+        const int ib_current = (k0/WARP_SIZE + 0) % MMI8_N_BARRIERS;
+        const int ib_next    = (k0/WARP_SIZE + 1) % MMI8_N_BARRIERS;
+        constexpr int prefetch = 1;
+
+        load_tiles_i8<mmi8_x, mmi8_y, need_check_x, need_check_y, prefetch>(
+            x_qs_low, y_qs_low, y_bs, tile_x_qs, tile_y_qs,
+            ncols_x, nrows_x, nrows_y, ncols_y, k0, barriers, ib_next);
+
+        barriers[ib_current].arrive_and_wait();
+
+        vec_dot_i8<mmi8_x>(tile_x_qs, tile_y_qs, fc, k0, ib_current);
+
+        __syncthreads();
+    }
+
+    {
+        const int k0 = ncols_x/sizeof(int) - WARP_SIZE;
+        const int ib_current = (k0/WARP_SIZE + 0) % MMI8_N_BARRIERS;
+
+        barriers[ib_current].arrive_and_wait();
+
+        vec_dot_i8<mmi8_x>(tile_x_qs, tile_y_qs, fc, k0, ib_current);
+
+        __syncthreads();
+    }
+
+    int *   tmp_fc  = tile_x_qs + threadIdx.y*(32*8);
+    float * tmp_d_j = ((float *) tile_y_qs) + threadIdx.y*WARP_SIZE;
+
+    const int row_dst = blockIdx.x*mmi8_y + 32*threadIdx.y + threadIdx.x;
+    const float d_i = x_d[row_dst] / I8_MAX_FRAG_SCALE;
+#pragma unroll
+    for (int j0 = 0; j0 < mmi8_x; j0 += 8) {
+
+        nvcuda::wmma::store_matrix_sync(tmp_fc, fc[j0/8], 32, nvcuda::wmma::mem_col_major);
+
+        if ((mmi8_y % WARP_SIZE != 0 && 32*threadIdx.y + threadIdx.x >= mmi8_y) || (need_check_x && row_dst >= nrows_dst)) {
+            continue;
+        }
+
+        if (j0 % WARP_SIZE == 0) {
+            const int col_dst = blockIdx.y*mmi8_x + j0 + threadIdx.x;
+            tmp_d_j[threadIdx.x] = y_d[col_dst];
+        }
+
+#pragma unroll
+        for (int l = 0; l < 32*8; l += WARP_SIZE) {
+            const int col_dst = blockIdx.y*mmi8_x + j0 + l/32;
+
+            if (need_check_y && col_dst >= ncols_dst) {
+                continue;
+            }
+
+            const float d_j = tmp_d_j[(j0 + l/32) % WARP_SIZE];
+            dst[col_dst*nrows_dst + row_dst] = tmp_fc[l + threadIdx.x] * d_i*d_j;
+        }
+    }
+// #else
+//     (void)x_qs_low;(void)x_d;(void)y_qs_low;(void)y_qs_high;(void)y_d;(void)dst;
+//     (void)ncols_x;(void)nrows_x;(void)ncols_y;(void)nrows_y;(void)nrows_dst;
+//     bad_arch();
+// #endif // __CUDA_ARCH__ >= CC_VOLTA && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
+
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
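`mul_mat_i8` overlaps the global-to-shared tile copies with the tensor core math: while `vec_dot_i8` consumes the tile guarded by `barriers[ib_current]`, `load_tiles_i8` has already bound the next tile's `memcpy_async` to `barriers[ib_next]`. A stripped-down sketch of that double-buffering pattern (assumes `blockDim.x == 256`; all names are mine):

```cpp
#include <cuda/barrier>

using barrier_t = cuda::barrier<cuda::thread_scope_block>;

// Stripped-down version of the prefetch loop in mul_mat_i8: while the tile
// for step t is being consumed, the tile for step t+1 is already in flight.
__global__ void pipelined_sum(const int * x, int * out, int n_tiles) {
    __shared__ alignas(128) barrier_t bar[2];
    __shared__ int tile[2][256];

    if (threadIdx.x < 2) {
        init(&bar[threadIdx.x], blockDim.x);
    }
    __syncthreads();

    // kick off the first copy
    cuda::memcpy_async(&tile[0][threadIdx.x], &x[threadIdx.x], sizeof(int), bar[0]);

    int acc = 0;
    for (int t = 0; t < n_tiles; ++t) {
        if (t + 1 < n_tiles) {  // prefetch the next tile into the other buffer
            cuda::memcpy_async(&tile[(t + 1) % 2][threadIdx.x],
                               &x[(t + 1)*256 + threadIdx.x], sizeof(int), bar[(t + 1) % 2]);
        }
        bar[t % 2].arrive_and_wait();   // wait only for the tile consumed now
        acc += tile[t % 2][threadIdx.x];
        __syncthreads();                // buffer may be overwritten next round
    }
    atomicAdd(out, acc);
}
```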
@@ -6273,6 +6709,99 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
+template <bool frag_scales>
+static void convert_float_to_i8_cuda(
+    const float * x, int * y_qs_low, int * y_bs, float * y_d, const int kx, const int ky, cudaStream_t stream) {
+
+    GGML_ASSERT(ky % 8 == 0);
+    const dim3 num_blocks(1, ky/8, 1);
+    const dim3 block_size(1024, 1, 1);
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    const int smempb = g_device_caps[id].smempb;
+    static bool smem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!smem_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 0>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
+        smem_limit_raised[id] = true;
+    }
+    const int nrows_smem = smempb / (WARP_SIZE*sizeof(int) + kx*sizeof(half));
+
+    switch (nrows_smem) {
+        case 0:
+            convert_float_to_i8<frag_scales, 0><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 1:
+            convert_float_to_i8<frag_scales, 1><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 2:
+            convert_float_to_i8<frag_scales, 2><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 3:
+            convert_float_to_i8<frag_scales, 3><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 4:
+            convert_float_to_i8<frag_scales, 4><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 5:
+            convert_float_to_i8<frag_scales, 5><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 6:
+            convert_float_to_i8<frag_scales, 6><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        case 7:
+            convert_float_to_i8<frag_scales, 7><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+        default:
+            convert_float_to_i8<frag_scales, 8><<<num_blocks, block_size, smempb, stream>>>
+                (x, y_qs_low, y_bs, y_d, kx);
+            break;
+    }
+}
+
+static void convert_q8_0_to_i8_cuda(const void * x, int * y_qs_low, float * y_d, const int kx, const int ky, cudaStream_t stream) {
+    const dim3 num_blocks(1, ky, 1);
+    const size_t smem_vals = kx*ggml_type_size(GGML_TYPE_Q8_0)/ggml_blck_size(GGML_TYPE_Q8_0);
+    GGML_ASSERT(smem_vals % (4*sizeof(int)) == 0);
+    const size_t smem_barrier = 128; // actually only need 8 bytes but pad to 128 for alignment
+    const size_t smem_buf_iw  = WARP_SIZE*sizeof(float); // reduction buffer between barrier and row data
+    const size_t smem = smem_vals + smem_barrier + smem_buf_iw;
+
+    switch (kx) {
+        case 4096:
+            convert_q8_0_to_i8<128, 4096><<<num_blocks, 128, smem, stream>>>(x, y_qs_low, y_d, kx);
+            break;
+        case 5120:
+            convert_q8_0_to_i8<128, 5120><<<num_blocks, 128, smem, stream>>>(x, y_qs_low, y_d, kx);
+            break;
+        case 11008:
+            convert_q8_0_to_i8<512, 11008><<<num_blocks, 512, smem, stream>>>(x, y_qs_low, y_d, kx);
+            break;
+        case 13824:
+            convert_q8_0_to_i8<512, 13824><<<num_blocks, 512, smem, stream>>>(x, y_qs_low, y_d, kx);
+            break;
+        default:
+            fprintf(stderr, "%d\n", kx);
+            convert_q8_0_to_i8<256, 0><<<num_blocks, 256, smem, stream>>>(x, y_qs_low, y_d, kx);
+            break;
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
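Rough numbers for the `nrows_smem` heuristic above (illustrative, using the formula as written):

```cpp
// nrows_smem = smempb / (WARP_SIZE*sizeof(int) + kx*sizeof(half))
//
// 96 KiB opt-in shared memory, kx = 4096:
//   98304 / (32*4 + 4096*2) = 98304 / 8320 = 11  -> default case, template arg 8:
//   all 8 rows of the group are staged in shared memory as half.
// 48 KiB, kx = 11008:
//   49152 / (128 + 22016) = 2                    -> case 2: only the first two
//   rows are staged; the remaining rows are re-read from global memory.
```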
@@ -7052,6 +7581,157 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     }
 }
 
+static const uint32_t g_mmi8_configs[6] = {0x00068100, 0x00040100, 0x00080080, 0x00040080, 0x00040040, 0x00020040};
+
+static float get_mmi8_config_score(const uint32_t config, const int nrows_x, const int ncols_y) {
+    const int64_t mmi8_x = (config >> 12) & 0x00000FFF;
+    const int64_t mmi8_y = (config >>  0) & 0x00000FFF;
+    const int64_t nwarps = mmi8_y/32;
+    const int64_t smem = 512 + sizeof(int)*(mmi8_x + mmi8_y)*MMI8_TILE_STRIDE;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    if (smem > g_device_caps[id].smempb) {
+        return 0.0f;
+    }
+
+    const int64_t blocks_per_sm = g_device_caps[id].smem / smem;
+
+    if (blocks_per_sm == 0) {
+        return 0.0f;
+    }
+
+    float score = 1.0f;
+
+    const int64_t grid_x = (nrows_x + mmi8_y - 1) / mmi8_y;
+    score *= nrows_x;
+    score /= grid_x*mmi8_y;
+
+    const int64_t grid_y = (ncols_y + mmi8_x - 1) / mmi8_x;
+    score *= ncols_y;
+    score /= grid_y*mmi8_x;
+
+    if (nrows_x % mmi8_y == 0) {
+        score *= 1.03;
+    }
+    if (ncols_y % mmi8_x == 0) {
+        score *= 1.03;
+    }
+
+    score *= mmi8_x*mmi8_y;
+    score /= mmi8_x*mmi8_y + 8196;
+
+    const int64_t nsm     = g_device_caps[id].nsm;
+    const int64_t nblocks = grid_x*grid_y;
+    const int64_t nwaves  = (nblocks + nsm*blocks_per_sm - 1) / (nsm*blocks_per_sm);
+    score *= nblocks;
+    score /= nsm*blocks_per_sm;
+    score /= nwaves;
+
+    if (mmi8_x > ncols_y) {
+        score -= mmi8_x - ncols_y;
+    }
+
+    return score;
+}
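The packed config words decode as two 12-bit tile sizes; the `// 104x256`-style comments below show the decoded values. A small decoding sketch (hypothetical helper, mirroring the shifts in `get_mmi8_config_score`):

```cpp
#include <cstdint>

// How the entries of g_mmi8_configs decode:
//   0x00068100 -> mmi8_x = 0x068 = 104, mmi8_y = 0x100 = 256  ("104x256")
//   0x00040080 -> mmi8_x = 0x040 =  64, mmi8_y = 0x080 = 128  ("64x128")
// ggml_mul_mat_i8_cuda later ORs 0x01000000 / 0x02000000 into the winning
// config to select the need_check_x / need_check_y template variants.
static inline void decode_mmi8_config(uint32_t config, int & mmi8_x, int & mmi8_y) {
    mmi8_x = (config >> 12) & 0xFFF;
    mmi8_y = (config >>  0) & 0xFFF;
}
```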
+
+#define MMI8_SMEM(mmi8_x, mmi8_y) (512 + sizeof(int)*(0x##mmi8_x + 0x##mmi8_y)*MMI8_TILE_STRIDE)
+
+#define MMI8_SWITCH_CASE(mmi8_x, mmi8_y)                                                              \
+    case 0x00##mmi8_x##mmi8_y:                                                                        \
+        mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, false>                                              \
+            <<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>>                           \
+            (x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
+        break;                                                                                        \
+    case 0x01##mmi8_x##mmi8_y:                                                                        \
+        mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, false>                                               \
+            <<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>>                           \
+            (x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
+        break;                                                                                        \
+    case 0x02##mmi8_x##mmi8_y:                                                                        \
+        mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, true>                                               \
+            <<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>>                           \
+            (x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
+        break;                                                                                        \
+    case 0x03##mmi8_x##mmi8_y:                                                                        \
+        mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, true>                                                \
+            <<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>>                           \
+            (x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
+        break;                                                                                        \
+
+#define MMI8_RAISE_SMEM_LIMIT(mmi8_x, mmi8_y)                                                  \
+    CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, false>,          \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb));               \
+    CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, false>,           \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb));               \
+    CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, true>,           \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb));               \
+    CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, true>,            \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb));               \
+
+static void ggml_mul_mat_i8_cuda(
+    const int * x_qs_low, const float * x_d, const int * y_qs_low, const int * y_bs, const float * y_d, float * dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    static bool smem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!smem_limit_raised[id]) {
+        MMI8_RAISE_SMEM_LIMIT(068, 100) // 104x256
+        MMI8_RAISE_SMEM_LIMIT(040, 100) //  64x256
+        MMI8_RAISE_SMEM_LIMIT(080, 080) // 128x128
+        MMI8_RAISE_SMEM_LIMIT(040, 080) //  64x128
+        MMI8_RAISE_SMEM_LIMIT(040, 040) //  64x64
+        MMI8_RAISE_SMEM_LIMIT(020, 040) //  32x64
+
+        smem_limit_raised[id] = true;
+    }
+
+    uint32_t best_config = 0;
+    float best_score = 0.0f;
+    for (uint64_t i = 0; i < sizeof(g_mmi8_configs)/sizeof(g_mmi8_configs[0]); ++i) {
+        const uint32_t config = g_mmi8_configs[i];
+        const float score = get_mmi8_config_score(config, nrows_x, ncols_y);
+
+        if (score > best_score) {
+            best_config = config;
+            best_score  = score;
+        }
+    }
+    GGML_ASSERT(best_config != 0);
+
+    const int mmi8_x = (best_config >> 12) & 0x00000FFF;
+    const int mmi8_y = (best_config >>  0) & 0x00000FFF;
+    const int nwarps = mmi8_y/32;
+    GGML_ASSERT(mmi8_x <= mmi8_y); // Otherwise not enough space for fragment scales.
+
+    const int block_num_x = (nrows_x + mmi8_y - 1) / mmi8_y;
+    const int block_num_y = (ncols_y + mmi8_x - 1) / mmi8_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmi8_y != 0) {
+        best_config |= 0x01000000;
+    }
+    if (ncols_y % mmi8_x != 0) {
+        best_config |= 0x02000000;
+    }
+
+    switch (best_config) {
+        MMI8_SWITCH_CASE(068, 100) // 104x256
+        MMI8_SWITCH_CASE(040, 100) //  64x256
+        MMI8_SWITCH_CASE(080, 080) // 128x128
+        MMI8_SWITCH_CASE(040, 080) //  64x128
+        MMI8_SWITCH_CASE(040, 040) //  64x64
+        MMI8_SWITCH_CASE(020, 040) //  32x64
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7290,7 +7970,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     const dim3 block_nums(nrows_x, 1, 1);
     const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    if (shmem <= g_device_caps[g_main_device].smpb) {
+    if (shmem <= g_device_caps[g_main_device].smempb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f16<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
@@ -7333,7 +8013,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     const dim3 block_nums(nrows_x, 1, 1);
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    if (shmem < g_device_caps[g_main_device].smpb) {
+    if (shmem < g_device_caps[g_main_device].smempb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
@@ -7682,7 +8362,9 @@ GGML_CALL void ggml_init_cublas() {
 #else
             g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            g_device_caps[id].smpb = prop.sharedMemPerBlock;
+            g_device_caps[id].nsm    = prop.multiProcessorCount;
+            g_device_caps[id].smem   = prop.sharedMemPerMultiprocessor;
+            g_device_caps[id].smempb = prop.sharedMemPerBlockOptin;
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_default_tensor_split[id] /= total_vram;
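`smempb` is now taken from `sharedMemPerBlockOptin` rather than `sharedMemPerBlock`: dynamic shared memory beyond the default 48 KiB is only available to kernels that opt in per function, which is what `MMI8_RAISE_SMEM_LIMIT` does above. A minimal sketch of the pattern:

```cpp
#include <cuda_runtime.h>

extern __shared__ char big_smem[];
__global__ void needs_big_smem() { big_smem[threadIdx.x] = 0; }

// Kernels requesting more than the default 48 KiB of dynamic shared memory
// must raise the per-function limit first (illustrative sketch).
static void launch_with_96k(cudaStream_t stream) {
    const size_t smem = 96*1024;
    cudaFuncSetAttribute(needs_big_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    needs_big_smem<<<1, 32, smem, stream>>>();
}
```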
@@ -8160,6 +8842,62 @@ static void ggml_cuda_op_mul_mat_q(
     (void) src1_ddf_i;
 }
 
+static void ggml_cuda_op_mul_mat_i8(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
+
+    cuda_pool_alloc<char> src0_ddi8(ne00*ne01 + ne01*sizeof(float));
+    int   * src0_qs_low = (int *)   (src0_ddi8.get() + 0);
+    float * src0_d      = (float *) (src0_ddi8.get() + ne00*ne01);
+
+    cuda_pool_alloc<char> src1_ddi8((ne10 + sizeof(float))*src1_ncols + sizeof(int)*(ne10/16)*(src1_ncols/8));
+    int   * src1_qs_low = (int *)   (src1_ddi8.get() + 0);
+    int   * src1_bs     = (int *)   (src1_ddi8.get() + ne10*src1_ncols);
+    float * src1_d      = (float *) (src1_bs + (ne10/16) * (src1_ncols/8));
+
+    // efficient conversion to i8 only implemented for q8_0, convert to intermediary float as workaround
+    cuda_pool_alloc<float> src0_workaround;
+    to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(src0->type);
+
+    switch (src0->type) {
+        case GGML_TYPE_Q8_0:
+            convert_q8_0_to_i8_cuda(src0_dd_i, src0_qs_low, src0_d, ne00, ne01, stream);
+            break;
+        default:
+            src0_workaround.alloc(ne00*ne01);
+            to_fp32(src0_dd_i, src0_workaround.get(), ne00*ne01, stream);
+            convert_float_to_i8_cuda<false>(src0_workaround.get(), src0_qs_low, nullptr, src0_d, ne00, ne01, stream);
+            break;
+    }
+
+    convert_float_to_i8_cuda<true>(src1_ddf_i, src1_qs_low, src1_bs, src1_d, ne10, ne11, stream);
+
+    ggml_mul_mat_i8_cuda(src0_qs_low, src0_d, src1_qs_low, src1_bs, src1_d, dst_dd_i,
+        ne00, row_diff, src1_ncols, ne10, nrows_dst, stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_ddq_i;
+    (void) src1_padded_row_size;
+}
+
 static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
     int64_t min_compute_capability = INT_MAX;
     int64_t max_compute_capability = INT_MIN;
@@ -9576,7 +10314,27 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
-            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const bool use_mmi8 = false;
+#else
+            // const bool layer_0 = strncmp(src0->name+0, "blk.0.", 6) == 0;
+            // const bool layer_1 = strncmp(src0->name+0, "blk.1.", 6) == 0;
+            // const bool layer_2 = strncmp(src0->name+0, "blk.2.", 6) == 0;
+            // const bool attn_q      = strncmp(src0->name+6, "attn_q",       6) == 0 || strncmp(src0->name+7, "attn_q",       6) == 0;
+            // const bool attn_k      = strncmp(src0->name+6, "attn_k",       6) == 0 || strncmp(src0->name+7, "attn_k",       6) == 0;
+            // const bool attn_v      = strncmp(src0->name+6, "attn_v",       6) == 0 || strncmp(src0->name+7, "attn_v",       6) == 0;
+            // const bool attn_output = strncmp(src0->name+6, "attn_output", 11) == 0 || strncmp(src0->name+7, "attn_output", 11) == 0;
+            // const bool ffn_up      = strncmp(src0->name+6, "ffn_up",       6) == 0 || strncmp(src0->name+7, "ffn_up",       6) == 0;
+            // const bool ffn_gate    = strncmp(src0->name+6, "ffn_gate",     8) == 0 || strncmp(src0->name+7, "ffn_gate",     8) == 0;
+            // const bool ffn_down    = strncmp(src0->name+6, "ffn_down",     8) == 0 || strncmp(src0->name+7, "ffn_down",     8) == 0;
+            // const bool use_mmi8 = min_compute_capability >= CC_VOLTA && (attn_q || attn_k || attn_output || ffn_gate || ffn_down);
+            const bool use_mmi8 = min_compute_capability >= CC_VOLTA;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            if (use_mmi8) {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_i8, false);
+            } else {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
+            }
         }
     } else {