Skip to content

cuda: 1.2x faster dequantization kernel #2809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4197,9 +4197,55 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}

// make_dfloat2(x, y): construct a dfloat2 from two scalar components.
// With GGML_CUDA_F16, dfloat2 is a packed half2 (built via __halves2half2);
// otherwise it is a plain float2 (built via make_float2).
#ifdef GGML_CUDA_F16
#define make_dfloat2(x, y) __halves2half2((x), (y))
#else
#define make_dfloat2(x, y) make_float2((x), (y))
#endif

// Component-wise product of two dfloat2 values.
// F16 build: a single packed half2 multiply; F32 build: two float multiplies.
static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
#ifdef GGML_CUDA_F16
    return __hmul2(a, b);
#else
    dfloat2 prod;
    prod.x = a.x * b.x;
    prod.y = a.y * b.y;
    return prod;
#endif
}

// Convert a dfloat2 to float2: widen half2 -> float2 in the F16 build,
// identity in the F32 build (dfloat2 already is float2 there).
static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) {
#ifdef GGML_CUDA_F16
    return __half22float2(a);
#else
    return a;
#endif
}

// Dequantize q4_0 data: each thread produces 4 output floats from 2 quant
// bytes — the two low nibbles go to the first half of the output block, the
// two high nibbles to the second half. Expects a 1D launch with enough
// threads to cover k/4 elements; k must be a multiple of 4.
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i*4 >= k) {
        return;
    }

    const int ib  = i/(QK4_0/4); // index of the quant block
    const int iqs = i%(QK4_0/4); // index of the byte pair within the block

    const block_q4_0 * x = (const block_q4_0 *) vx;
    const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2); // two quant bytes at once
    const dfloat  d  = x[ib].d;
    const dfloat2 dd = make_dfloat2(d, d); // scale broadcast to both lanes

    // Low nibbles (values are stored offset by 8) -> first half of the block.
    const dfloat2 vlo = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
    *(float2 *)(y + ib*QK4_0 + iqs*2) = dfloat22float2(dfmul2(vlo, dd));

    // High nibbles -> second half of the block.
    const dfloat2 vhi = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
    *(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = dfloat22float2(dfmul2(vhi, dd));
}

// Launch the q4_0 dequantization kernel on `stream`, writing k floats to y.
// NOTE: the scraped diff lost its +/- markers here, leaving both the old and
// the new `num_blocks` declarations and launches in place (a redeclaration
// that cannot compile); this is the reconstructed post-change version.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    GGML_ASSERT(k % 4 == 0); // each thread handles exactly 4 values
    const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block_q4_0<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
Expand Down Expand Up @@ -5711,6 +5757,7 @@ inline void ggml_cuda_op_alibi(
(void) src1;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}

Expand Down