
Commit 486d061

check abort_callback on main thread only
1 parent d27f26e commit 486d061
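
Note on the change: the abort callback is now polled only by thread 0 of the compute pool, once per graph node; an abort is recorded in the new shared `ec` field of ggml_compute_state_shared, every worker observes it after the node barrier and stops, and ggml_graph_compute() returns it as the final status. The sketch below shows the caller side of this mechanism under stated assumptions: the helper names (run_graph, my_abort_cb, g_stop) are hypothetical and work-buffer handling is reduced to a plain malloc/free.

#include "ggml.h"

#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

// hypothetical stop flag, e.g. set from a signal handler or another thread
static volatile bool g_stop = false;

// returning true requests an abort; after this commit only thread 0 calls it
static bool my_abort_cb(void * data) {
    (void) data;
    return g_stop;
}

// hypothetical helper: plan the graph, attach the callback, run, report status
static enum ggml_status run_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

    cplan.abort_callback      = my_abort_cb;
    cplan.abort_callback_data = NULL;

    // ggml_graph_compute() asserts work_data != NULL whenever work_size > 0
    uint8_t * work = NULL;
    if (cplan.work_size > 0) {
        work = malloc(cplan.work_size);
        cplan.work_data = work;
    }

    const enum ggml_status status = ggml_graph_compute(gf, &cplan);

    free(work);
    return status; // GGML_STATUS_ABORTED if the callback fired mid-graph
}

Because only one thread reads the callback, an abort takes effect for all workers at the next node boundary instead of being checked (and potentially raced on) by every thread.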

File tree

1 file changed: 89 additions, 135 deletions

ggml.c

Lines changed: 89 additions & 135 deletions
@@ -1744,13 +1744,14 @@ struct ggml_compute_state_shared {
     void * abort_callback_data;
 
     atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
+
+    enum ggml_status ec;
 };
 
 struct ggml_compute_state {
     ggml_thread_t thrd;
     int ith;
     struct ggml_compute_state_shared * shared;
-    enum ggml_status ec;
 };
 
 struct ggml_compute_params {
@@ -3001,7 +3002,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
         }
     }
 #else
-    GGML_UNUSED(numa_flag);
+    UNUSED(numa_flag);
     // TODO
 #endif
 }
@@ -15980,7 +15981,7 @@ static void ggml_compute_forward_unary(
 static void ggml_compute_forward_get_rel_pos_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    GGML_UNUSED(params);
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
@@ -18317,8 +18318,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
-                case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
                     {
                         n_tasks = 1;
                     } break;
@@ -18341,24 +18342,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_MUL_MAT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_MUL_MAT_ID:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
             {
-                // FIXME: the cost of launching additional threads decreases performance with GPU offloading
+                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+                // decreases performance with GPU offloading
                 //n_tasks = n_threads;
                 n_tasks = 1;
             } break;
@@ -18390,14 +18383,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
             } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_IM2COL:
-            {
-                n_tasks = n_threads;
-            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 n_tasks = n_threads;
@@ -18408,33 +18395,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = 1;
             } break;
         case GGML_OP_UPSCALE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_PAD:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARANGE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARGSORT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             {
@@ -18482,9 +18448,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 n_tasks = n_threads;
@@ -18514,37 +18477,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
     return n_tasks;
 }
 
-static thread_ret_t ggml_graph_compute_thread(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-
-    const struct ggml_cgraph * cgraph = state->shared->cgraph;
-    const struct ggml_cplan * cplan = state->shared->cplan;
-
-    set_numa_thread_affinity(state->ith);
-
-    struct ggml_compute_params params = {
-        /*.ith =*/ state->ith,
-        /*.nth =*/ state->shared->n_threads,
-        /*.wsize =*/ cplan->work_size,
-        /*.wdata =*/ cplan->work_data,
-        /*.shared=*/ state->shared,
-    };
-
-    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
-        if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->ec = GGML_STATUS_ABORTED;
-            return 0;
-        }
-        struct ggml_tensor * node = cgraph->nodes[node_n];
-
-        ggml_compute_forward(&params, node);
-
-        ggml_barrier(state->shared);
-    }
-
-    return 0;
-}
-
 struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
@@ -18713,75 +18645,48 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     return cplan;
 }
 
-static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
-    enum ggml_status compute_status = GGML_STATUS_SUCCESS;
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
-#ifdef GGML_USE_OPENMP
-    if (n_threads > 1) {
-        #pragma omp parallel num_threads(n_threads)
-        {
-            #pragma omp single
-            {
-                // update the number of threads from the actual number of threads that we got from OpenMP
-                n_threads = omp_get_num_threads();
-                workers[0].shared->n_threads = n_threads;
-                workers[0].shared->current_chunk = n_threads;
-            }
-            ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
-        }
-    } else {
-        ggml_graph_compute_thread(&workers[0]);
-    }
-#else
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
+    const struct ggml_cgraph * cgraph = state->shared->cgraph;
+    const struct ggml_cplan * cplan = state->shared->cplan;
 
-    // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
+    set_numa_thread_affinity(state->ith);
 
-    // join or kill thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; j++) {
-            const int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
+    struct ggml_compute_params params = {
+        /*.ith =*/ state->ith,
+        /*.nth =*/ state->shared->n_threads,
+        /*.wsize =*/ cplan->work_size,
+        /*.wdata =*/ cplan->work_data,
+        /*.shared=*/ state->shared,
+    };
+
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(&params, node);
+
+        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->shared->ec = GGML_STATUS_ABORTED;
         }
-    }
-#endif
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
 
-    for (int j = 0; j < n_threads; j++) {
-        if (workers[j].ec != GGML_STATUS_SUCCESS) {
-            compute_status = workers[j].ec;
+        ggml_barrier(state->shared);
+
+        if (state->shared->ec != GGML_STATUS_SUCCESS) {
             break;
         }
     }
-    return compute_status;
+
+    return 0;
 }
 
 enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
-    {
-        GGML_ASSERT(cplan);
-        GGML_ASSERT(cplan->n_threads > 0);
-
-        if (cplan->work_size > 0) {
-            GGML_ASSERT(cplan->work_data);
-        }
-    }
+    GGML_ASSERT(cplan);
+    GGML_ASSERT(cplan->n_threads > 0);
+    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
 
     int n_threads = cplan->n_threads;
 
-#if defined(GGML_USE_OPENMP)
-    n_threads = MIN(n_threads, omp_get_max_threads());
-#endif
-
     struct ggml_compute_state_shared state_shared = {
         /*.cgraph =*/ cgraph,
         /*.cgraph_plan =*/ cplan,
@@ -18791,21 +18696,70 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         /*.abort_callback =*/ NULL,
         /*.abort_callback_data =*/ NULL,
         /*.current_chunk =*/ 0,
+        /*.ec =*/ GGML_STATUS_SUCCESS,
     };
+
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update the number of threads from the actual number of threads that we got from OpenMP
+                n_threads = omp_get_num_threads();
+                state_shared.n_threads = n_threads;
+            }
+
+            struct ggml_compute_state worker = {
+                .thrd = 0,
+                .ith = omp_get_thread_num(),
+                .shared = &state_shared,
+            };
+            ggml_graph_compute_thread(&worker);
+        }
+    } else {
+        struct ggml_compute_state worker = {
+            .thrd = 0,
+            .ith = 0,
+            .shared = &state_shared,
+        };
+        ggml_graph_compute_thread(&worker);
+    }
+#else
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
     for (int j = 0; j < n_threads; ++j) {
         workers[j] = (struct ggml_compute_state) {
             .thrd = 0,
             .ith = j,
             .shared = &state_shared,
-            .ec = GGML_STATUS_SUCCESS,
         };
     }
 
-    enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
+    // create thread pool
+    for (int j = 1; j < n_threads; ++j) {
+        const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
+
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
+
+    // join or kill thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+    }
+#endif
+
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
 
-    return compute_status;
+    return state_shared.ec;
 }
 
 enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
