Skip to content

Commit 1973399

Browse files
committed
server : simplify
ggml-ci
1 parent 831e63a commit 1973399

File tree

1 file changed

+42
-49
lines changed

1 file changed

+42
-49
lines changed

examples/server/server.cpp

Lines changed: 42 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -2251,6 +2251,8 @@ struct server_context {
22512251

22522252
id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
22532253

2254+
slot.i_batch = -1;
2255+
22542256
common_sampler_accept(slot.smpl, id, true);
22552257

22562258
slot.n_decoded += 1;
@@ -2277,73 +2279,64 @@ struct server_context {
22772279
slot.print_timings();
22782280
send_final_response(slot);
22792281
metrics.on_prediction(slot);
2282+
continue;
22802283
}
22812284
}
22822285

2283-
slot.i_batch = -1;
2284-
2285-
if (slot.ctx_dft) {
2286-
struct common_speculative_params params_spec;
2287-
params_spec.n_draft = params.n_draft;
2288-
params_spec.n_reuse = 256;
2289-
params_spec.p_min = 0.9f;
2290-
2291-
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
2286+
// check if the slot supports speculative decoding
2287+
if (!slot.ctx_dft) {
2288+
continue;
2289+
}
22922290

2293-
if (draft.size() > params.n_draft_min) {
2294-
common_batch_clear(slot.batch_spec);
2295-
common_batch_add(slot.batch_spec, id, slot.n_past++, { slot.id }, true);
2291+
// TODO: configurable through requests
2292+
struct common_speculative_params params_spec;
2293+
params_spec.n_draft = params.n_draft;
2294+
params_spec.n_reuse = 256;
2295+
params_spec.p_min = 0.9f;
22962296

2297-
for (size_t i = 0; i < draft.size(); ++i) {
2298-
common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
2299-
}
2297+
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
23002298

2301-
llama_decode(ctx, slot.batch_spec);
2302-
2303-
const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
2299+
if (params.n_draft_min > (int) draft.size()) {
2300+
continue;
2301+
}
23042302

2305-
slot.n_past += ids.size() - 1;
2303+
// construct the speculation batch
2304+
common_batch_clear(slot.batch_spec);
2305+
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
23062306

2307-
slot.cache_tokens.push_back(id);
2307+
for (size_t i = 0; i < draft.size(); ++i) {
2308+
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
2309+
}
23082310

2309-
for (size_t i = 0; i < ids.size(); ++i) {
2310-
completion_token_output result;
2311+
llama_decode(ctx, slot.batch_spec);
23112312

2312-
id = ids[i];
2313+
// the accepted tokens from the speculation
2314+
const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
23132315

2314-
common_sampler_accept(slot.smpl, id, true);
2316+
slot.n_past += ids.size();
2317+
slot.n_decoded += ids.size();
23152318

2316-
slot.n_decoded += 1;
2317-
if (slot.n_decoded == 1) {
2318-
slot.t_start_generation = ggml_time_us();
2319-
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2320-
metrics.on_prompt_eval(slot);
2321-
}
2319+
slot.cache_tokens.push_back(id);
2320+
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
23222321

2323-
result.tok = id;
2322+
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
23242323

2325-
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2324+
for (size_t i = 0; i < ids.size(); ++i) {
2325+
completion_token_output result;
23262326

2327-
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2328-
result.probs.push_back({
2329-
cur_p->data[i].id,
2330-
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2331-
});
2332-
}
2327+
id = ids[i];
23332328

2334-
if (!process_token(result, slot)) {
2335-
// release slot because of stop condition
2336-
slot.release();
2337-
slot.print_timings();
2338-
send_final_response(slot);
2339-
metrics.on_prediction(slot);
2340-
break;
2341-
}
2342-
}
2329+
common_sampler_accept(slot.smpl, id, true);
23432330

2344-
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
2331+
result.tok = id;
23452332

2346-
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
2333+
if (!process_token(result, slot)) {
2334+
// release slot because of stop condition
2335+
slot.release();
2336+
slot.print_timings();
2337+
send_final_response(slot);
2338+
metrics.on_prediction(slot);
2339+
break;
23472340
}
23482341
}
23492342
}

0 commit comments

Comments (0)