Skip to content

Commit 321a061

Browse files
jeffbolznvmglambda
authored andcommitted
vulkan: support multi/vision rope, and noncontiguous rope (ggml-org#11902)
1 parent 6200a00 commit 321a061

File tree

7 files changed

+204
-41
lines changed

7 files changed

+204
-41
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ struct vk_device_struct {
251251
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
252252
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
253253
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
254+
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
255+
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
254256
vk_pipeline pipeline_argsort_f32;
255257
vk_pipeline pipeline_sum_rows_f32;
256258
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
494496
float corr_dims[2];
495497
float theta_scale;
496498
uint32_t has_ff;
499+
uint32_t ne02;
500+
uint32_t s1;
501+
uint32_t s2;
502+
int32_t sections[4];
497503
};
498504

499505
struct vk_op_soft_max_push_constants {
@@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
21802186

21812187
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21822188
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2189+
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2190+
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21832191

21842192
if (device->float_controls_rte_fp16) {
21852193
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21862194
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2195+
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2196+
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21872197
} else {
21882198
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21892199
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2200+
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2201+
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
21902202
}
21912203

21922204
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
@@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53075319
{
53085320
const int mode = ((const int32_t *) dst->op_params)[2];
53095321
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5322+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5323+
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
53105324

53115325
if (is_neox) {
53125326
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53155329
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
53165330
return ctx->device->pipeline_rope_neox_f16;
53175331
}
5332+
} else if (is_mrope && !is_vision) {
5333+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5334+
return ctx->device->pipeline_rope_multi_f32;
5335+
}
5336+
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5337+
return ctx->device->pipeline_rope_multi_f16;
5338+
}
5339+
} else if (is_vision) {
5340+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5341+
return ctx->device->pipeline_rope_vision_f32;
5342+
}
5343+
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5344+
return ctx->device->pipeline_rope_vision_f16;
5345+
}
53185346
} else {
53195347
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53205348
return ctx->device->pipeline_rope_norm_f32;
@@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
53855413
case GGML_OP_CLAMP:
53865414
case GGML_OP_PAD:
53875415
case GGML_OP_REPEAT:
5416+
case GGML_OP_ROPE:
53885417
return true;
53895418
default:
53905419
return false;
@@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
61496178

61506179
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
61516180
const int n_dims = ((int32_t *) dst->op_params)[1];
6152-
// const int mode = ((int32_t *) dst->op_params)[2];
6181+
const int mode = ((int32_t *) dst->op_params)[2];
61536182
// const int n_ctx = ((int32_t *) dst->op_params)[3];
61546183
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
61556184
const float freq_base = ((float *) dst->op_params)[5];
@@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
61586187
const float attn_factor = ((float *) dst->op_params)[8];
61596188
const float beta_fast = ((float *) dst->op_params)[9];
61606189
const float beta_slow = ((float *) dst->op_params)[10];
6190+
int sections[4] {};
6191+
if (mode & GGML_ROPE_TYPE_MROPE) {
6192+
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
6193+
}
61616194

61626195
float corr_dims[2];
61636196
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
61646197

61656198
const float theta_scale = powf(freq_base, -2.0f/n_dims);
61666199

6200+
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
6201+
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
6202+
61676203
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
61686204
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
61696205
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
6170-
src2 != nullptr,
6206+
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
6207+
sections[0], sections[1], sections[2], sections[3],
61716208
}, dryrun);
61726209
}
61736210

@@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
82648301
case GGML_OP_REPEAT:
82658302
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
82668303
case GGML_OP_ROPE:
8267-
{
8268-
const int mode = ((const int32_t *) op->op_params)[2];
8269-
if (mode & GGML_ROPE_TYPE_MROPE) {
8270-
return false;
8271-
}
8272-
if (mode & GGML_ROPE_TYPE_VISION) {
8273-
return false;
8274-
}
8275-
return ggml_is_contiguous(op->src[0]);
8276-
}
82778304
case GGML_OP_NONE:
82788305
case GGML_OP_RESHAPE:
82798306
case GGML_OP_VIEW:
@@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88318858
const float attn_factor = ((float *) tensor->op_params)[8];
88328859
const float beta_fast = ((float *) tensor->op_params)[9];
88338860
const float beta_slow = ((float *) tensor->op_params)[10];
8834-
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8861+
if (mode & GGML_ROPE_TYPE_MROPE) {
8862+
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
8863+
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8864+
} else {
8865+
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8866+
}
88358867
} else if (tensor->op == GGML_OP_UNARY) {
88368868
switch (ggml_get_unary_op(tensor)) {
88378869
case GGML_UNARY_OP_SILU:

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
2525
float corr_dims[2];
2626
float theta_scale;
2727
uint has_ff;
28+
uint ne02;
29+
uint s1;
30+
uint s2;
31+
int sections[4];
2832
} p;
2933

3034
float rope_yarn_ramp(const float low, const float high, const uint i0) {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#version 450
2+
3+
#include "rope_head.comp"
4+
5+
void main() {
6+
const uint i0 = 2*gl_GlobalInvocationID.y;
7+
uint ne0 = p.ncols;
8+
uint ne1 = p.p_delta_rows;
9+
uint ne2 = p.ne02;
10+
11+
if (i0 >= ne0) {
12+
return;
13+
}
14+
15+
const uint row_dst = gl_GlobalInvocationID.x;
16+
17+
if (i0 >= p.n_dims) {
18+
const uint i = row_dst*ne0 + i0;
19+
20+
data_d[i + 0] = data_a[i + 0];
21+
data_d[i + 1] = data_a[i + 1];
22+
23+
return;
24+
}
25+
26+
const uint row_x = row_dst % ne1;
27+
const uint channel_x = row_dst / ne1;
28+
29+
const uint idst = row_dst*ne0 + i0/2;
30+
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
31+
32+
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
33+
const int sec_w = p.sections[1] + p.sections[0];
34+
const uint sector = (i0 / 2) % sect_dims;
35+
36+
float theta_base = 0.0;
37+
if (sector < p.sections[0]) {
38+
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
39+
}
40+
else if (sector >= p.sections[0] && sector < sec_w) {
41+
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
42+
}
43+
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
44+
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
45+
}
46+
else if (sector >= sec_w + p.sections[2]) {
47+
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
48+
}
49+
50+
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
51+
52+
float cos_theta, sin_theta;
53+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
54+
55+
const float x0 = float(data_a[ix + 0]);
56+
const float x1 = float(data_a[ix + p.n_dims/2]);
57+
58+
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
59+
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
60+
}

ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,41 @@
33
#include "rope_head.comp"
44

55
void main() {
6-
const uint col = gl_GlobalInvocationID.y * 2;
7-
const uint row = gl_GlobalInvocationID.x;
6+
const uint i0 = 2*gl_GlobalInvocationID.y;
7+
uint ne0 = p.ncols;
8+
uint ne1 = p.p_delta_rows;
89

9-
if (col >= p.ncols) {
10+
if (i0 >= ne0) {
1011
return;
1112
}
1213

13-
if (col >= p.n_dims) {
14-
const uint i = row*p.ncols + col;
14+
const uint row_dst = gl_GlobalInvocationID.x;
15+
16+
if (i0 >= p.n_dims) {
17+
const uint i = row_dst*ne0 + i0;
1518

1619
data_d[i + 0] = data_a[i + 0];
1720
data_d[i + 1] = data_a[i + 1];
1821

1922
return;
2023
}
2124

22-
const uint i = row*p.ncols + col/2;
23-
const uint i2 = row/p.p_delta_rows;
25+
const uint row_x = row_dst % ne1;
26+
const uint channel_x = row_dst / ne1;
27+
28+
const uint idst = row_dst*ne0 + i0/2;
29+
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
2430

25-
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
31+
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
2632

27-
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
33+
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
2834

2935
float cos_theta, sin_theta;
30-
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
36+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
3137

32-
const float x0 = float(data_a[i + 0]);
33-
const float x1 = float(data_a[i + p.n_dims/2]);
38+
const float x0 = float(data_a[ix + 0]);
39+
const float x1 = float(data_a[ix + p.n_dims/2]);
3440

35-
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
36-
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
41+
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
42+
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
3743
}

ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,41 @@
33
#include "rope_head.comp"
44

55
void main() {
6-
const uint col = gl_GlobalInvocationID.y * 2;
7-
const uint row = gl_GlobalInvocationID.x;
6+
const uint i0 = 2*gl_GlobalInvocationID.y;
7+
uint ne0 = p.ncols;
8+
uint ne1 = p.p_delta_rows;
89

9-
if (col >= p.ncols) {
10+
if (i0 >= ne0) {
1011
return;
1112
}
1213

13-
if (col >= p.n_dims) {
14-
const uint i = row*p.ncols + col;
14+
const uint row_dst = gl_GlobalInvocationID.x;
15+
16+
if (i0 >= p.n_dims) {
17+
const uint i = row_dst*ne0 + i0;
1518

1619
data_d[i + 0] = data_a[i + 0];
1720
data_d[i + 1] = data_a[i + 1];
1821

1922
return;
2023
}
2124

22-
const uint i = row*p.ncols + col;
23-
const uint i2 = row/p.p_delta_rows;
25+
const uint row_x = row_dst % ne1;
26+
const uint channel_x = row_dst / ne1;
27+
28+
const uint idst = row_dst*ne0 + i0;
29+
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
2430

25-
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
31+
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
2632

27-
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
33+
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
2834

2935
float cos_theta, sin_theta;
30-
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
36+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
3137

32-
const float x0 = float(data_a[i + 0]);
33-
const float x1 = float(data_a[i + 1]);
38+
const float x0 = float(data_a[ix + 0]);
39+
const float x1 = float(data_a[ix + 1]);
3440

35-
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
36-
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
41+
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
42+
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
3743
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#version 450
2+
3+
#include "rope_head.comp"
4+
5+
void main() {
6+
const uint i0 = 2*gl_GlobalInvocationID.y;
7+
uint ne0 = p.ncols;
8+
uint ne1 = p.p_delta_rows;
9+
uint ne2 = p.ne02;
10+
11+
if (i0 >= ne0) {
12+
return;
13+
}
14+
15+
const uint row_dst = gl_GlobalInvocationID.x;
16+
17+
const uint row_x = row_dst % ne1;
18+
const uint channel_x = row_dst / ne1;
19+
20+
const uint idst = row_dst*ne0 + i0/2;
21+
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
22+
23+
const int sect_dims = p.sections[0] + p.sections[1];
24+
const int sec_w = p.sections[1] + p.sections[0];
25+
const uint sector = (i0 / 2) % sect_dims;
26+
27+
float theta_base = 0.0;
28+
if (sector < p.sections[0]) {
29+
const uint p0 = sector;
30+
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
31+
}
32+
else if (sector >= p.sections[0] && sector < sec_w) {
33+
const uint p0 = sector - p.sections[0];
34+
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
35+
}
36+
37+
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
38+
39+
float cos_theta, sin_theta;
40+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
41+
42+
const float x0 = float(data_a[ix + 0]);
43+
const float x1 = float(data_a[ix + p.n_dims]);
44+
45+
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
46+
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
47+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,14 @@ void process_shaders() {
491491
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
492492
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
493493

494+
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
495+
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
496+
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
497+
498+
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
499+
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
500+
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
501+
494502
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
495503

496504
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

0 commit comments

Comments
 (0)