
Commit 2392f7a

ggml : add ggml_graph_compute_with_ctx()
- backwards compatible API
- deduplicates a lot of copy-paste
1 parent 8e1f0b6 commit 2392f7a
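The deduplication the message refers to is easiest to see at a call site. Before this commit, every place that ran a graph had to plan, allocate, and free its own work buffer; afterwards a single call suffices, with the scratch memory taken from an existing ggml_context. A minimal sketch of the before/after pattern (a caller fragment, not from the commit: gf, ctx, and n_threads are assumed to exist, and malloc/free need <stdlib.h>):

    // before: every call site carried this boilerplate
    struct ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
    if (cplan.work_size > 0) {
        cplan.work_data = malloc(cplan.work_size);   // freed again below
        GGML_ASSERT(cplan.work_data);
    }
    ggml_graph_compute(&gf, &cplan);
    if (cplan.work_data) {
        free(cplan.work_data);
    }

    // after: the work data is carved out of an existing context
    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);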

File tree: ggml.c, ggml.h, tests/test-grad0.c, tests/test-opt.c

4 files changed: 29 additions, 125 deletions

ggml.c

Lines changed: 15 additions & 17 deletions
@@ -16493,21 +16493,17 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
     }
 }
 
-// TODO: avoid allocating memory frequently.
-// TODO: make part of public API - use different name and put warning that it makes allocations
-static void ggml_graph_compute_helper(struct ggml_cgraph * cgraph, int n_threads) {
+// same as ggml_graph_compute() but the work data is allocated as a part of the context
+// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
-    if (cplan.work_size > 0) {
-        cplan.work_data = malloc(cplan.work_size);
-        GGML_ASSERT(cplan.work_data);
-    }
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+    GGML_ASSERT(buf);
 
-    ggml_graph_compute(cgraph, &cplan);
+    cplan.work_data = buf->data;
 
-    if (cplan.work_data) {
-        free(cplan.work_data);
-    }
+    ggml_graph_compute(cgraph, &cplan);
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -17292,6 +17288,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 //
 
 static enum ggml_opt_result ggml_opt_adam(
+        struct ggml_context * ctx,
         struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -17346,7 +17343,7 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -17427,7 +17424,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params.n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
@@ -17498,6 +17495,7 @@ struct ggml_lbfgs_iteration_data {
 };
 
 static enum ggml_opt_result linesearch_backtracking(
+        struct ggml_context * ctx,
         const struct ggml_opt_params * params,
         int nx,
         float * x,
@@ -17549,7 +17547,7 @@ static enum ggml_opt_result linesearch_backtracking(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params->n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
 
         ggml_opt_get_grad(np, ps, g);
 
@@ -17669,7 +17667,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     ggml_opt_get_grad(np, ps, g);
 
@@ -17728,7 +17726,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
            // linesearch failed - go back to the previous point and return
@@ -18030,7 +18028,7 @@ enum ggml_opt_result ggml_opt_resume_g(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(opt, opt->params, f, gf, gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
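As the comment on the new function warns, the work buffer is created with ggml_new_tensor_1d() inside the caller's context, so the context must have enough free memory for cplan.work_size (plus the tensor's bookkeeping overhead) or the allocation will fail. A hedged sketch of how a caller might check for headroom first; the explicit plan, the mem_size variable, and the overhead constant are illustrative only and not part of this commit:

    // illustrative pre-check, assuming ctx was created with mem_size bytes:
    // plan once to learn the required scratch size, then confirm the context
    // still has room before ggml_graph_compute_with_ctx() plans again
    // internally and allocates the GGML_TYPE_I8 work tensor
    struct ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
    const size_t overhead = 1024; // rough allowance for the tensor header, illustrative
    GGML_ASSERT(ggml_used_mem(ctx) + cplan.work_size + overhead <= mem_size);

    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);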

ggml.h

Lines changed: 5 additions & 1 deletion
(The ggml_set_param hunk below is a whitespace-only realignment of the declaration; the functional change is the new ggml_graph_compute_with_ctx() declaration.)

@@ -1306,7 +1306,7 @@ extern "C" {
 
     GGML_API void ggml_set_param(
             struct ggml_context * ctx,
-            struct ggml_tensor * tensor);
+            struct ggml_tensor  * tensor);
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
@@ -1319,6 +1319,10 @@ extern "C" {
     GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
     GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
     GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
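A self-contained usage sketch of the new declaration (not part of the commit; the tensor sizes, the 16 MB mem_size, and the arithmetic are illustrative, and the context must be large enough to hold the tensors, the graph, and the work data):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            .mem_size   = 16*1024*1024,  // must also cover the work data
            .mem_buffer = NULL,
            .no_alloc   = false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 2.0f);
        struct ggml_tensor * b = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 3.0f);
        struct ggml_tensor * s = ggml_sum(ctx, ggml_mul(ctx, a, b));

        struct ggml_cgraph gf = ggml_build_forward(s);

        // work data is allocated inside ctx - no manual plan/malloc/free needed
        ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads*/ 1);

        printf("sum = %f\n", ggml_get_f32_1d(s, 0)); // 8 * 2 * 3 = 48

        ggml_free(ctx);
        return 0;
    }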

tests/test-grad0.c

Lines changed: 5 additions & 64 deletions
@@ -195,32 +195,6 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
 
 }
 
-struct work_buffer {
-    size_t size;
-    uint8_t * data;
-};
-
-static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
-    if (size == 0) {
-        return NULL;
-    }
-
-    GGML_ASSERT(buf);
-
-    if (buf->size == 0) {
-        buf->data = malloc(size);
-        buf->size = size;
-    } else if (buf->size < size) {
-        buf->data = realloc(buf->data, size);
-        buf->size = size;
-    } else {
-        // skip shrinking.
-    }
-
-    GGML_ASSERT(buf->data);
-    return buf->data;
-}
-
 bool check_gradient(
         const char * op_name,
         struct ggml_context * ctx0,
@@ -247,28 +221,12 @@ bool check_gradient(
     struct ggml_cgraph gf = ggml_build_forward (f);
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
 
-    struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL };
-
-    {
-        struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
-        if (pf.work_size > 0) {
-            pf.work_data = malloc(pf.work_size);
-            GGML_ASSERT(pf.work_data);
-        }
-        ggml_graph_compute(&gf, &pf);
-        if (pf.work_data) {
-            free(pf.work_data);
-        }
-    }
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     ggml_graph_reset  (&gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    {
-        struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
-        pf.work_data = work_buffer_resize(&buf, pf.work_size);
-        ggml_graph_compute(&gf, &pf);
-    }
+    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
     // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
     // ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
@@ -282,24 +240,15 @@ bool check_gradient(
             const float xp = x0 + eps;
             set_element(x[i], k, xp);
 
-            {
-                struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
-                pf.work_data = work_buffer_resize(&buf, pf.work_size);
-                ggml_graph_compute(&gf, &pf);
-            }
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
             set_element(x[i], k, xm);
 
-            {
-                struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
-                pf.work_data = work_buffer_resize(&buf, pf.work_size);
-                ggml_graph_compute(&gf, &pf);
-            }
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f1 = ggml_get_f32_1d(f, 0);
-
             const float g0 = (f0 - f1)/(2.0f*eps);
 
             set_element(x[i], k, x0);
@@ -308,11 +257,7 @@ bool check_gradient(
             ggml_graph_reset  (&gf);
             ggml_set_f32      (f->grad, 1.0f);
 
-            {
-                struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
-                pf.work_data = work_buffer_resize(&buf, pf.work_size);
-                ggml_graph_compute(&gf, &pf);
-            }
+            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
             const float g1 = get_element(x[i]->grad, k);
 
@@ -328,10 +273,6 @@ bool check_gradient(
         }
     }
 
-    if (buf.data) {
-        free(buf.data);
-    }
-
     return true;
 }
 
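One thing this test makes visible: ggml_graph_compute_with_ctx() allocates a fresh GGML_TYPE_I8 work tensor on every call, and a ggml_context does not reclaim individual tensors, so in a loop like check_gradient's finite-difference sweep each call consumes roughly work_size bytes of ctx0 until the context is freed. A rough sizing note, with illustrative names (N, ctx0, gf, n_threads are assumptions, not from the commit):

    // illustrative: when calling the helper N times against the same context,
    // budget about N * (work_size + per-tensor overhead) of the context's
    // mem_size for scratch buffers, since each call bump-allocates its own
    for (int iter = 0; iter < N; ++iter) {
        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    }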

tests/test-opt.c

Lines changed: 4 additions & 43 deletions
@@ -115,31 +115,6 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
     ((float *)t->data)[idx] = value;
 }
 
-
-struct work_buffer {
-    size_t size;
-    uint8_t * data;
-};
-
-static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
-    if (size == 0) {
-        return NULL;
-    }
-
-    if (buf->size == 0) {
-        buf->data = malloc(size);
-        buf->size = size;
-    } else if (buf->size < size) {
-        buf->data = realloc(buf->data, size);
-        buf->size = size;
-    } else {
-        // skip shrinking.
-    }
-
-    GGML_ASSERT(buf->data);
-    return buf->data;
-}
-
 int main(void) {
     struct ggml_init_params params = {
         .mem_size = 1024*1024*1024,
@@ -163,16 +138,10 @@ int main(void) {
     struct ggml_tensor * d = ggml_sub(ctx, c, ab);
     struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
 
-
     struct ggml_cgraph ge = ggml_build_forward(e);
-    ggml_graph_reset (&ge);
+    ggml_graph_reset(&ge);
 
-    struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL };
-    {
-        struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
-        pe.work_data = work_buffer_resize(&buf, pe.work_size);
-        ggml_graph_compute(&ge, &pe);
-    }
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
 
     const float fe = ggml_get_f32_1d(e, 0);
     printf("%s: e = %.4f\n", __func__, fe);
@@ -181,17 +150,9 @@ int main(void) {
 
     ggml_opt(ctx, opt_params, e);
 
-    ggml_graph_reset (&ge);
+    ggml_graph_reset(&ge);
 
-    {
-        struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
-        pe.work_data = work_buffer_resize(&buf, pe.work_size);
-        ggml_graph_compute(&ge, &pe);
-    }
-
-    if (buf.data) {
-        free(buf.data);
-    }
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
 
     const float fe_opt = ggml_get_f32_1d(e, 0);
     printf("%s: original e = %.4f\n", __func__, fe);
