Skip to content

vulkan: support multi/vision rope, and noncontiguous rope #11902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
Expand Down Expand Up @@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
float corr_dims[2];
float theta_scale;
uint32_t has_ff;
uint32_t ne02;
uint32_t s1;
uint32_t s2;
int32_t sections[4];
};

struct vk_op_soft_max_push_constants {
Expand Down Expand Up @@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}

ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
Expand Down Expand Up @@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
{
const int mode = ((const int32_t *) dst->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
Expand All @@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_multi_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_multi_f16;
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_vision_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_vision_f16;
}
} else {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_norm_f32;
Expand Down Expand Up @@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_ROPE:
return true;
default:
return false;
Expand Down Expand Up @@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,

static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
// const int mode = ((int32_t *) dst->op_params)[2];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const float freq_base = ((float *) dst->op_params)[5];
Expand All @@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
const float attn_factor = ((float *) dst->op_params)[8];
const float beta_fast = ((float *) dst->op_params)[9];
const float beta_slow = ((float *) dst->op_params)[10];
int sections[4] {};
if (mode & GGML_ROPE_TYPE_MROPE) {
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
}

float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);

ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
sections[0], sections[1], sections[2], sections[3],
}, dryrun);
}

Expand Down Expand Up @@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return ggml_is_contiguous(op->src[0]);
}
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
Expand Down Expand Up @@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const float attn_factor = ((float *) tensor->op_params)[8];
const float beta_fast = ((float *) tensor->op_params)[9];
const float beta_slow = ((float *) tensor->op_params)[10];
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
if (mode & GGML_ROPE_TYPE_MROPE) {
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
} else {
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
}
} else if (tensor->op == GGML_OP_UNARY) {
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
float corr_dims[2];
float theta_scale;
uint has_ff;
uint ne02;
uint s1;
uint s2;
int sections[4];
} p;

float rope_yarn_ramp(const float low, const float high, const uint i0) {
Expand Down
60 changes: 60 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#version 450

#include "rope_head.comp"

void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;

if (i0 >= ne0) {
return;
}

const uint row_dst = gl_GlobalInvocationID.x;

if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;

data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];

return;
}

const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}

const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);

const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);

data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
34 changes: 20 additions & 14 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,41 @@
#include "rope_head.comp"

void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;

if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}

if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;

if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;

data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];

return;
}

const uint i = row*p.ncols + col/2;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);

const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);

const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.n_dims/2]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);

data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
34 changes: 20 additions & 14 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,41 @@
#include "rope_head.comp"

void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;

if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}

if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;

if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;

data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];

return;
}

const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;

const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);

const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);

const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + 1]);

data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
47 changes: 47 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#version 450

#include "rope_head.comp"

void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;

if (i0 >= ne0) {
return;
}

const uint row_dst = gl_GlobalInvocationID.x;

const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

const int sect_dims = p.sections[0] + p.sections[1];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < p.sections[0]) {
const uint p0 = sector;
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
}
else if (sector >= p.sections[0] && sector < sec_w) {
const uint p0 = sector - p.sections[0];
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
}

const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);

const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims]);

data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
8 changes: 8 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,14 @@ void process_shaders() {
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});

string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
Expand Down
Loading