@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }

-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
-
 // decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
             lctx.n_outputs = n_outputs_new;
         }

-        std::pair<int32_t, ggml_compute_threadpool_t> threads =
-            llama_swap_threadpools(lctx, n_tokens);
-
-        int n_threads = threads.first;
-        ggml_compute_threadpool_t threadpool = threads.second;
+        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

         GGML_ASSERT(n_threads > 0);

@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;

-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;

     GGML_ASSERT(n_threads > 0);

     ggml_backend_sched_reset(lctx.sched);
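
With `llama_swap_threadpools` removed, both `llama_decode_internal` and `llama_encode_internal` now pick the thread count and threadpool directly from `n_tokens`, and the per-call `ggml_pause_threadpool` bookkeeping from the deleted helper goes away. The sketch below isolates that selection logic; `cparams_stub`, `lctx_stub`, and `select_threading` are simplified stand-ins for the internal types, not the real definitions.

```cpp
// Minimal sketch of the per-call selection that replaces llama_swap_threadpools.
// The stub types below are simplified assumptions, not the real internal structs.
#include <cstdint>

typedef struct ggml_compute_threadpool * ggml_compute_threadpool_t;

struct cparams_stub {
    int32_t n_threads;        // threads for single-token (generation) calls
    int32_t n_threads_batch;  // threads for multi-token (prompt/batch) calls
};

struct lctx_stub {
    cparams_stub cparams;
    ggml_compute_threadpool_t threadpool       = nullptr;
    ggml_compute_threadpool_t threadpool_batch = nullptr; // attach defaults this to threadpool
};

// Pick the thread count and threadpool for one decode/encode call.
static void select_threading(const lctx_stub & lctx, int32_t n_tokens,
                             int32_t & n_threads, ggml_compute_threadpool_t & threadpool) {
    n_threads  = n_tokens == 1 ? lctx.cparams.n_threads : lctx.cparams.n_threads_batch;
    threadpool = n_tokens == 1 ? lctx.threadpool        : lctx.threadpool_batch;
}
```

When nothing is attached, `threadpool` stays `nullptr`, which (as the deleted comment noted) ggml treats as a request for a disposable threadpool.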
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {

 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }

 void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool = nullptr;
+    ctx->threadpool_batch = nullptr;
 }

 void llama_backend_free(void) {
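
After this hunk a single `llama_attach_threadpool` call installs both pools, a missing batch pool falls back to the single-token one, and `llama_detach_threadpool` clears both. A hedged caller-side sketch follows; the creation/release calls (`ggml_threadpool_params_default`, `ggml_create_threadpool`, `ggml_release_threadpool`) are assumptions about the ggml threadpool API at this revision and are not part of this diff, so check `ggml.h` for the exact names.

```cpp
// Hypothetical usage of the unified attach/detach API after this change.
// Only llama_attach_threadpool / llama_detach_threadpool come from the diff;
// the ggml threadpool creation/release calls are assumed names.
#include "llama.h"
#include "ggml.h"

static void run_with_threadpools(struct llama_context * ctx) {
    // One pool sized for generation, one for prompt/batch processing (assumed API).
    struct ggml_threadpool_params tpp       = ggml_threadpool_params_default(/*n_threads=*/8);
    struct ggml_threadpool_params tpp_batch = ggml_threadpool_params_default(/*n_threads=*/16);

    ggml_compute_threadpool_t threadpool       = ggml_create_threadpool(&tpp);
    ggml_compute_threadpool_t threadpool_batch = ggml_create_threadpool(&tpp_batch);

    // A single call now attaches both pools; passing nullptr for the batch pool
    // makes it default to the single-token pool inside the function.
    llama_attach_threadpool(ctx, threadpool, threadpool_batch);

    // ... llama_decode()/llama_encode() pick the pool based on n_tokens ...

    // Detach clears both pointers on the context before the pools are released.
    llama_detach_threadpool(ctx);
    ggml_release_threadpool(threadpool);
    ggml_release_threadpool(threadpool_batch);
}
```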