@@ -37,8 +37,8 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.4f;
-    const float p_split  = 0.3f;
+    const float p_accept = 0.80f;
+    const float p_split  = 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+    params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params);
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
             const std::string token_str = llama_token_to_piece(ctx_tgt, id);
 
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
         // TODO: simplify
         {
-            LOG("keeping sequence %d\n", s_keep);
+            LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
             llama_kv_cache_seq_keep(ctx_dft, s_keep);
             llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
@@ -277,7 +277,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
                     drafts[s].drafting = false;
                     continue;
                 }
@@ -337,16 +337,14 @@ int main(int argc, char ** argv) {
 
                 llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
-                // no need to evaluate the last drafted token, since we won't use the result
-                if (batch_tgt.n_tokens > n_draft) {
-                    drafts[s].drafting = false;
-                    continue;
-                }
-
                 // add the token to the batch for batched decoding with the draft model
                 drafts[s].i_batch_dft = batch_dft.n_tokens;
 
                 llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+
+                if (batch_tgt.n_tokens > n_draft) {
+                    drafts[s].drafting = false;
+                }
             }
         }
 
@@ -365,11 +363,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // account for the last drafted token that we didn't evaluate
-        if (batch_tgt.n_tokens > n_draft) {
-            ++n_drafted;
-        }
-
         // evaluate the target model on the drafted tokens
         {
            llama_kv_cache_seq_keep(ctx_tgt, 0);
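
For context (not part of the patch): a minimal standalone sketch, with purely illustrative names rather than the llama.cpp API, of how the two thresholds tuned above gate the drafting loop. p_accept is an absolute cutoff on the draft model's top-token probability (drafting stops when confidence drops below it), and p_split is the threshold on a runner-up token's probability above which an extra draft sequence is branched off.

// sketch only: hypothetical candidate data, not llama.cpp structures
#include <cstdio>
#include <vector>

struct tok_prob { int id; float p; };

int main() {
    const float p_accept = 0.80f; // absolute threshold on the top draft probability
    const float p_split  = 0.10f; // threshold on the runner-up probability for branching

    // hypothetical per-step candidate lists produced by a draft model
    std::vector<std::vector<tok_prob>> steps = {
        {{11, 0.95f}, {12, 0.03f}},  // confident      -> keep drafting, no split
        {{21, 0.85f}, {22, 0.12f}},  // confident      -> keep drafting, split a branch
        {{31, 0.55f}, {32, 0.30f}},  // below p_accept -> stop drafting
    };

    for (size_t i = 0; i < steps.size(); ++i) {
        const auto & cur_p = steps[i];
        if (cur_p[0].p < p_accept) {
            printf("step %zu: stop drafting, %.2f < %.2f\n", i, cur_p[0].p, p_accept);
            break;
        }
        const bool split = cur_p[1].p > p_split;
        printf("step %zu: draft token %d (p = %.2f)%s\n",
               i, cur_p[0].id, cur_p[0].p, split ? ", split new branch" : "");
    }
    return 0;
}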