@@ -16568,8 +16568,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.n_nodes   =*/ 0,
         /*.n_leafs   =*/ 0,
         /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
-        /*.work_size =*/ 0,
-        /*.work      =*/ NULL,
         /*.nodes     =*/ { NULL },
         /*.grads     =*/ { NULL },
         /*.leafs     =*/ { NULL },
@@ -16740,6 +16738,7 @@ void clear_numa_thread_affinity(void) {}
 
 struct ggml_compute_state_shared {
     struct ggml_cgraph * cgraph;
+    struct ggml_cgraph_context * cgraph_ctx;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
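
For reference, the definition of the new struct ggml_cgraph_context does not appear in this diff. A minimal sketch of its assumed layout, inferred from the fields used later in the diff (work_size, work_data, planned); the actual definition may differ:

struct ggml_cgraph_context {
    size_t work_size; // bytes required for the scratch work buffer, set by ggml_graph_compute_plan()
    void * work_data; // the scratch buffer itself; may be provided by the caller after planning
    bool   planned;   // set to true once ggml_graph_compute_plan() has run
};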
@@ -16769,6 +16768,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
     struct ggml_cgraph * cgraph = state->shared->cgraph;
+    struct ggml_cgraph_context * ctx = state->shared->cgraph_ctx;
 
     const int n_threads = state->shared->n_threads;
     set_numa_thread_affinity(state->ith, n_threads);
@@ -16783,8 +16783,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /*.type  =*/ GGML_TASK_FINALIZE,
                 /*.ith   =*/ 0,
                 /*.nth   =*/ 0,
-                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+                /*.wsize =*/ ctx->work_size,
+                /*.wdata =*/ ctx->work_data,
             };
 
         if (node_n != -1) {
@@ -16844,8 +16844,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /*.type  =*/ GGML_TASK_COMPUTE,
                 /*.ith   =*/ state->ith,
                 /*.nth   =*/ node->n_tasks,
-                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+                /*.wsize =*/ ctx->work_size,
+                /*.wdata =*/ ctx->work_data,
             };
 
         if (state->ith < node->n_tasks) {
@@ -16856,23 +16856,20 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16856
16856
return 0;
16857
16857
}
16858
16858
16859
- void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
16860
- const int n_threads = cgraph->n_threads;
16859
+ // Prepare for graph computing.
16860
+ // Will set: node->n_tasks, ctx->{work_size, planned}
16861
+ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
16862
+ GGML_ASSERT(ctx);
16863
+ // This function is actually reentrant, but duplicate calls is unnecessary.
16864
+ GGML_ASSERT(ctx->work_size == 0);
16865
+ GGML_ASSERT(ctx->work_data == NULL);
16866
+ GGML_ASSERT(!ctx->planned);
16861
16867
16862
- struct ggml_compute_state_shared state_shared = {
16863
- /*.cgraph =*/ cgraph,
16864
- /*.perf_node_start_cycles =*/ 0,
16865
- /*.perf_node_start_time_us =*/ 0,
16866
- /*.n_threads =*/ n_threads,
16867
- /*.n_active =*/ n_threads,
16868
- /*.node_n =*/ -1,
16869
- };
16870
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16868
+ int n_threads = cgraph->n_threads;
16869
+ size_t work_size = 0;
16871
16870
16872
16871
// initialize tasks + work buffer
16873
16872
{
16874
- size_t work_size = 0;
16875
-
16876
16873
// thread scheduling for the different operations
16877
16874
for (int i = 0; i < cgraph->n_nodes; i++) {
16878
16875
struct ggml_tensor * node = cgraph->nodes[i];
@@ -17202,19 +17199,53 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             }
         }
+    }
 
-        if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            GGML_ASSERT(false); // TODO: better handling
-        }
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads - 1);
+    }
+
+    ctx->work_size = work_size;
+    ctx->work_data = NULL;
+    ctx->planned   = true;
+}
 
-        if (work_size > 0 && cgraph->work == NULL) {
-            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
+    if (ctx == NULL) {
+        ctx = alloca(sizeof(struct ggml_cgraph_context));
+        GGML_ASSERT(ctx);
+        ctx->work_size = 0;
+        ctx->work_data = NULL;
+        ctx->planned = false;
+    } else {
+        // work_size and work_data MAY still have their default values even if the context has been planned.
+        if (ctx->work_size > 0) {
+            GGML_ASSERT(ctx->work_data);
+        }
+    }
 
-            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
-            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+    if (!ctx->planned) {
+        ggml_graph_compute_plan(ctx, cgraph);
+        if (ctx->work_size > 0) {
+            ctx->work_data = malloc(ctx->work_size);
+            GGML_ASSERT(ctx->work_data);
+            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, ctx->work_size);
         }
     }
 
+    const int n_threads = cgraph->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.cgraph                  =*/ cgraph,
+        /*.cgraph_ctx              =*/ ctx,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
+    };
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
     // create thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; ++j) {
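
A minimal usage sketch of the resulting two-phase API, assuming the ggml_cgraph_context layout sketched earlier and a graph gf built elsewhere with ggml_build_forward (error handling omitted):

// Phase 1: plan - computes per-node n_tasks and the required work buffer size.
struct ggml_cgraph_context cctx = { /*.work_size =*/ 0, /*.work_data =*/ NULL, /*.planned =*/ false };
ggml_graph_compute_plan(&cctx, &gf);

// Phase 2: compute - the caller may now provide its own work buffer.
if (cctx.work_size > 0) {
    cctx.work_data = malloc(cctx.work_size);
    GGML_ASSERT(cctx.work_data);
}
ggml_graph_compute_v2(&cctx, &gf);
free(cctx.work_data);

Passing NULL instead of &cctx makes ggml_graph_compute_v2 plan and allocate internally, as in the fallback path above.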
@@ -17266,6 +17297,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     }
 }
 
+// Deprecated, kept only for backward compatibility.
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    UNUSED(ctx);
+    ggml_graph_compute_v2(NULL, cgraph);
+}
+
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * grad = cgraph->grads[i];
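
Existing call sites keep compiling unchanged through the deprecated wrapper; a hypothetical legacy caller, where ctx0 is a ggml_context:

ggml_graph_compute(ctx0, &gf); // ctx0 is now unused; routes through ggml_graph_compute_v2(NULL, &gf)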