@@ -2251,6 +2251,8 @@ struct server_context {
 
                     id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
+                    slot.i_batch = -1;
+
                     common_sampler_accept(slot.smpl, id, true);
 
                     slot.n_decoded += 1;
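(The reset of slot.i_batch moves up next to the sampling call because the rewritten block in the next hunk bails out with continue in several places, and the reset must happen on every path through the slot loop.)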
@@ -2277,73 +2279,64 @@ struct server_context {
                         slot.print_timings();
                         send_final_response(slot);
                         metrics.on_prediction(slot);
+                        continue;
                     }
                 }
 
-                slot.i_batch = -1;
-
-                if (slot.ctx_dft) {
-                    struct common_speculative_params params_spec;
-                    params_spec.n_draft = params.n_draft;
-                    params_spec.n_reuse = 256;
-                    params_spec.p_min   = 0.9f;
-
-                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+                // check if the slot supports speculative decoding
+                if (!slot.ctx_dft) {
+                    continue;
+                }
 
-                    if (draft.size() > params.n_draft_min) {
-                        common_batch_clear(slot.batch_spec);
-                        common_batch_add  (slot.batch_spec, id, slot.n_past++, { slot.id }, true);
+                // TODO: configurable through requests
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = params.n_draft;
+                params_spec.n_reuse = 256;
+                params_spec.p_min   = 0.9f;
 
-                        for (size_t i = 0; i < draft.size(); ++i) {
-                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
-                        }
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
-                        llama_decode(ctx, slot.batch_spec);
-
-                        const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
+                if (params.n_draft_min > (int) draft.size()) {
+                    continue;
+                }
 
-                        slot.n_past += ids.size() - 1;
+                // construct the speculation batch
+                common_batch_clear(slot.batch_spec);
+                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
 
-                        slot.cache_tokens.push_back(id);
+                for (size_t i = 0; i < draft.size(); ++i) {
+                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                }
 
-                        for (size_t i = 0; i < ids.size(); ++i) {
-                            completion_token_output result;
+                llama_decode(ctx, slot.batch_spec);
 
-                            id = ids[i];
+                // the accepted tokens from the speculation
+                const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
 
-                            common_sampler_accept(slot.smpl, id, true);
+                slot.n_past    += ids.size();
+                slot.n_decoded += ids.size();
 
-                            slot.n_decoded += 1;
-                            if (slot.n_decoded == 1) {
-                                slot.t_start_generation = ggml_time_us();
-                                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                                metrics.on_prompt_eval(slot);
-                            }
+                slot.cache_tokens.push_back(id);
+                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
 
-                            result.tok = id;
+                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
 
-                            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                for (size_t i = 0; i < ids.size(); ++i) {
+                    completion_token_output result;
 
-                            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                                result.probs.push_back({
-                                    cur_p->data[i].id,
-                                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                                });
-                            }
+                    id = ids[i];
 
-                            if (!process_token(result, slot)) {
-                                // release slot because of stop condition
-                                slot.release();
-                                slot.print_timings();
-                                send_final_response(slot);
-                                metrics.on_prediction(slot);
-                                break;
-                            }
-                        }
+                    common_sampler_accept(slot.smpl, id, true);
 
-                        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+                    result.tok = id;
 
-                        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        break;
                     }
                 }
             }
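In outline, the rewritten block turns each sampled token into one draft-and-verify round: the draft model proposes a continuation, the target model verifies the whole proposal in a single decode, and the KV cache is trimmed back to the accepted prefix. The sketch below condenses that flow into a standalone helper for readability; the helper itself is not part of the patch, and it assumes the slot members and common_* helpers behave as they are used in the diff above.

    // Illustrative condensation of the speculative round added above - not
    // part of the patch. `id` is the token just sampled by the target model;
    // slot.n_past is the position at which `id` is about to be decoded.
    static void speculative_round(llama_context * ctx, server_slot & slot,
                                  const common_speculative_params & params_spec,
                                  llama_token id) {
        // ask the draft model for a likely continuation of `id`
        llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);

        // verify everything in one decode: `id` at n_past, draft[i] at n_past + 1 + i
        common_batch_clear(slot.batch_spec);
        common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
        for (size_t i = 0; i < draft.size(); ++i) {
            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
        }
        llama_decode(ctx, slot.batch_spec);

        // accepted prefix of the draft plus one freshly sampled token
        const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);

        // `id` and the accepted draft tokens are now part of the sequence;
        // the last element of `ids` has not been decoded as input yet - it
        // becomes the `id` of the next round
        slot.n_past += ids.size();
        slot.cache_tokens.push_back(id);
        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);

        // evict the KV entries of the rejected draft tail
        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
    }

Worked through with assumed numbers, and assuming common_sampler_sample_n returns the agreeing prefix plus one final sample (which is how the surrounding bookkeeping reads): with slot.n_past == 100 and an 8-token draft, the batch decodes id at position 100 and the draft at positions 101..108. If the target model agrees with the first four draft tokens and then diverges, ids holds five tokens; slot.n_past becomes 105, cache_tokens gains id plus ids[0..3], and llama_kv_cache_seq_rm(ctx, slot.id, 105, -1) evicts the stale draft entries at positions 105..108, while ids[4] is fed back as the next round's id by the processing loop.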