
Commit e3c2202

threadpool: move all pause/resume logic into ggml
1 parent 5d4c0a1 commit e3c2202

File tree

6 files changed (+19 -88 lines)

examples/llama-bench/llama-bench.cpp (+1 -1)

@@ -1537,7 +1537,7 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, NULL);
 
     // warmup run
     if (t.n_prompt > 0) {
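
Note on the call above: with the new signature (see include/llama.h below), a caller that only creates a single threadpool passes NULL for the batch slot, and llama.cpp falls back to the same pool for batch processing. A minimal sketch of that pattern, assuming a pool created elsewhere with this PR's ggml threadpool API; the helper name is illustrative, not part of the change:

#include "llama.h"

// Illustrative helper (not part of this commit): attach one threadpool for
// both single-token and batch decoding. Passing NULL for the batch slot makes
// llama.cpp reuse `tp` for batches (see the src/llama.cpp hunk below).
static void attach_single_pool(struct llama_context * ctx, ggml_compute_threadpool_t tp) {
    llama_attach_threadpool(ctx, tp, /*threadpool_batch =*/ NULL);
}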

examples/main/main.cpp (+2 -7)

@@ -240,11 +240,6 @@ int main(int argc, char ** argv) {
             exit(1);
         }
 
-        llama_attach_batch_threadpool(ctx, threadpool_batch);
-        if (ctx_guidance) {
-            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
-        }
-
         // Start the non-batch threadpool in the paused state
         tpp.paused = true;
     }
@@ -255,9 +250,9 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
     if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool);
+        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);

ggml/src/ggml-backend.c (+5)

@@ -910,6 +910,11 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t threadpool) {
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
 
     struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_pause_threadpool(ctx->threadpool);
+    }
     ctx->threadpool = threadpool;
 }
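
This is the hook that lets llama.cpp drop its own swap/pause logic: setting a new threadpool on the CPU backend now pauses the previously attached one inside ggml. A minimal sketch of the caller's side, using only functions visible in this diff; the header name and helper are assumptions for illustration:

#include "ggml-backend.h"   // assumed location of ggml_backend_cpu_set_threadpool

// Illustrative helper (not part of this commit): switch the CPU backend to a
// different threadpool. After this change, the previously attached pool (if
// any, and different) is paused inside ggml_backend_cpu_set_threadpool, so the
// caller no longer needs an explicit ggml_pause_threadpool(old_pool) first.
static void switch_cpu_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t new_pool) {
    ggml_backend_cpu_set_threadpool(backend_cpu, new_pool);
}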

ggml/src/ggml.c (-3)

@@ -19198,9 +19198,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
             state->pending = false;
 
             ggml_graph_compute_thread(state);
-            if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-                break;
-            }
         }
     }

include/llama.h (+2 -9)

@@ -431,16 +431,9 @@ extern "C" {
     // Optional: an auto threadpool gets created in ggml if not passed explicitly
     LLAMA_API void llama_attach_threadpool(
             struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
-    LLAMA_API void llama_attach_batch_threadpool(
-            struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
+            ggml_compute_threadpool_t threadpool,
+            ggml_compute_threadpool_t threadpool_batch);
     LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);
-
-    // Pauses all attached threadpools
-    LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);
 
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
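
The two attach entry points collapse into one, and a single detach now clears both slots. A minimal usage sketch of the new API surface, mirroring what examples/main/main.cpp does above; the pools are assumed to be created elsewhere with this PR's ggml threadpool API, and the helper names are illustrative:

#include "llama.h"

// Illustrative helpers (not part of this commit).

// Attach one pool for single-token decoding and another (typically larger)
// pool for batch/prompt processing. Passing NULL as the batch pool is also
// valid and reuses `tp` for batches.
static void attach_pools(struct llama_context * ctx,
                         ggml_compute_threadpool_t tp,
                         ggml_compute_threadpool_t tp_batch) {
    llama_attach_threadpool(ctx, tp, tp_batch);
}

// A single detach now clears both the regular and the batch threadpool.
static void release_pools(struct llama_context * ctx) {
    llama_detach_threadpool(ctx);
}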

src/llama.cpp (+9 -68)

@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
 
-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
-
 // decode a batch of tokens by evaluating the transformer
 //
 // - lctx: llama context
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
         lctx.n_outputs = n_outputs_new;
     }
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
-
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
     GGML_ASSERT(n_threads > 0);
 
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
     GGML_ASSERT(n_threads > 0);
 
     ggml_backend_sched_reset(lctx.sched);
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
 
 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
 void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool = nullptr;
+    ctx->threadpool_batch = nullptr;
 }
 
 void llama_backend_free(void) {
