ggml : spread compute across threads in chunks #1507

Open · wants to merge 2 commits into master
ggml.c: 126 changes (80 additions, 46 deletions)
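
The idea of the change: instead of statically striding rows across threads (thread ith handling rows ith, ith+nth, ...), each op splits its rows into small chunks and every worker thread claims the next chunk from a shared atomic counter, so faster threads naturally pick up more of the work. Below is a minimal standalone sketch of that pattern, not code from the PR; process_row and compute_chunked are illustrative names.

#include <stdatomic.h>
#include <stdbool.h>

// hypothetical per-row kernel, stands in for the real op body
static void process_row(int ir) { (void) ir; }

// each worker thread calls this with the same shared counter `aic`;
// rows are claimed in chunks of `dr` until all `nr` rows are processed
static void compute_chunked(atomic_int * aic, int nr, int nth) {
    // split the rows into roughly 8*nth chunks, as the diff does
    const int dr = (nr + 8*nth - 1)/(8*nth);

    while (true) {
        // claim the next chunk; ir0 is the first row of this chunk
        const int ir0 = atomic_fetch_add(aic, dr);

        for (int ir = ir0; ir < ir0 + dr; ++ir) {
            if (ir >= nr) {
                break;
            }
            process_row(ir);
        }

        // stop once the claimed chunk reaches or passes the last row
        if (ir0 + dr >= nr) {
            break;
        }
    }
}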
@@ -3590,6 +3590,9 @@ struct ggml_compute_params {
// work buffer for all threads
size_t wsize;
void * wdata;

// atomic counter used to distribute chunks of work
atomic_int * aic;
};

//
@@ -9030,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
atomic_store(params->aic, 0);

return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int ith = params->ith; UNUSED(ith);
const int nth = params->nth;

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne03 = src0->ne[3]; UNUSED(ne03);

const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
@@ -9053,30 +9058,45 @@ static void ggml_compute_forward_rms_norm_f32(

const float eps = 1e-6f; // TODO: make this a parameter

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}
const int nr = ggml_nrows(src0);
const int dr = (nr + 8*nth - 1)/(8*nth);

float mean = sum/ne00;
while (true) {
const int ir0 = atomic_fetch_add(params->aic, dr);

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
for (int ir = ir0; ir < ir0 + dr; ++ir) {
if (ir >= nr) {
break;
}

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

const float scale = 1.0f/sqrtf(mean + eps);
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_vec_scale_f32(ne00, y, scale);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}

float mean = sum/ne00;

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }

const float scale = 1.0f/sqrtf(mean + eps);

ggml_vec_scale_f32(ne00, y, scale);
}

if (ir0 + dr >= nr) {
break;
}
}
}
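
Since rows are now handed out as flat indices rather than strided i01 values, the loop body above recovers the (i01, i02, i03) coordinates of each row from ir by integer division. A small sketch of that decomposition follows (illustrative helper, not part of the PR):

// a row of a tensor with ne[] = {ne00, ne01, ne02, ne03} has the flat index
//   ir = i01 + i02*ne01 + i03*ne01*ne02
// the loops above invert this with integer division and remainders
static void decompose_row_index(int ir, int ne01, int ne02,
                                int * i01, int * i02, int * i03) {
    *i03 =  ir/(ne02*ne01);
    *i02 = (ir - (*i03)*ne02*ne01)/ne01;
    *i01 =  ir - (*i03)*ne02*ne01 - (*i02)*ne01;
}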
@@ -9751,7 +9771,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

const int ith = params->ith;
const int ith = params->ith; UNUSED(ith);
const int nth = params->nth;

GGML_ASSERT(ne02 == ne12);
@@ -9867,50 +9887,57 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}

atomic_store(params->aic, 0);

return;
}

if (params->type == GGML_TASK_FINALIZE) {
return;
}

void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

// parallelize by src0 rows using ggml_vec_dot_q

// total rows in src0
const int nr = ne01*ne02*ne03;
const int nr = ggml_nrows(src0);
const int dr = (nr + 8*nth - 1)/(8*nth);

// rows per thread
const int dr = (nr + nth - 1)/nth;
while (true) {
const int ir0 = atomic_fetch_add(params->aic, dr);

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir0 + dr; ++ir) {
if (ir >= nr) {
break;
}

void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int i13 = i03;
const int i12 = i02;

const int i13 = i03;
const int i12 = i02;
const int i0 = i01;
const int i2 = i02;
const int i3 = i03;

const int i0 = i01;
const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));

void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));

float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);

assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
}

for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
if (ir0 + dr >= nr) {
break;
}
}

@@ -13749,6 +13776,7 @@ struct ggml_compute_state_shared {

// synchronization primitives
atomic_int n_ready;
atomic_int aic;
atomic_bool has_work;
atomic_bool stop; // stop all threads
};
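
The counter is owned by the shared thread-pool state extended above and has to be reset before each parallel pass: the op implementations do this in their GGML_TASK_INIT (or combined INIT/FINALIZE) branch with atomic_store(params->aic, 0), as shown in the hunks above, and the COMPUTE pass then consumes it with atomic_fetch_add. A simplified sketch of that per-node lifecycle, with hypothetical names, assuming a single INIT followed by COMPUTE on all threads:

#include <stdatomic.h>

struct state_shared {
    atomic_int aic;   // chunk counter shared by all worker threads
};

// INIT: runs before the parallel pass, resets the counter so that
// the first chunk handed out starts at row 0
static void node_init(struct state_shared * s) {
    atomic_store(&s->aic, 0);
}

// COMPUTE: every thread claims chunks of dr rows until none are left
static int claim_chunk(struct state_shared * s, int dr) {
    return atomic_fetch_add(&s->aic, dr);   // first row of the claimed chunk
}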
@@ -13817,6 +13845,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.spin =*/ GGML_LOCK_INITIALIZER,
/*.n_threads =*/ n_threads,
/*.n_ready =*/ 0,
/*.aic =*/ 0,
/*.has_work =*/ false,
/*.stop =*/ false,
};
@@ -13837,6 +13866,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = n_threads,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
},
.node = NULL,
.shared = &state_shared,
@@ -14126,6 +14156,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.nth =*/ node->n_tasks,
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
/*.aic =*/ &state_shared.aic,
};

ggml_compute_forward(&params, node);
@@ -14149,6 +14180,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
};
workers[j].node = node;
}
@@ -14164,6 +14196,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
}

params.type = GGML_TASK_COMPUTE;
params.aic = &state_shared.aic;
ggml_compute_forward(&params, node);

// wait for thread pool
@@ -14204,6 +14237,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
};
workers[j].node = node;
}