
Commit 4e82b2e

speculative : bug fixes
1 parent: 0e89203

File tree: 1 file changed (+10 -17 lines)


examples/speculative/speculative.cpp

Lines changed: 10 additions & 17 deletions
@@ -37,8 +37,8 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.4f;
-    const float p_split  = 0.3f;
+    const float p_accept = 0.80f;
+    const float p_split  = 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
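
Note on the retuned constants: p_accept is the confidence floor for continuing to draft (the hunk at -277 below stops drafting once the draft model's top-token probability falls under it), and p_split appears to control when a draft sequence is split into a parallel branch, presumably by comparing the runner-up token's probability against it. A minimal standalone sketch of how such thresholds are typically applied; the helper names are illustrative, not taken from this file:

    // drafting continues only while the draft model is confident in its top
    // token; a second branch is worth opening only when the runner-up token
    // is also reasonably probable
    static bool continue_drafting(float top_p, float p_accept) { return top_p >= p_accept; }
    static bool split_branch(float second_p, float p_split) { return second_p > p_split; }

Under that reading, raising p_accept to 0.80 stops drafting earlier when confidence drops, while lowering p_split to 0.10 makes splits easier to trigger.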
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.grammar.clear();             // the draft samplers will copy the target sampler's grammar
-    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+    params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params);
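
The temperature change replaces the fixed draft temperature of 1.0 with the target sampler's own temperature, clamped away from zero. A minimal sketch of the clamp, assuming the intent is that the draft sampler keeps producing a usable probability distribution even when the target is configured near-greedy:

    #include <algorithm>

    // mirror the target's temperature for the draft samplers, but never let
    // it reach 0, where sampling degenerates to a greedy argmax and the token
    // probabilities used for the p_accept check become less meaningful
    static float draft_temp(float target_temp) {
        return std::max(0.01f, target_temp);
    }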
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
 
         llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
-        //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+        //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
         const std::string token_str = llama_token_to_piece(ctx_tgt, id);
 
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
             // TODO: simplify
             {
-                LOG("keeping sequence %d\n", s_keep);
+                LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
                 llama_kv_cache_seq_keep(ctx_dft, s_keep);
                 llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
@@ -277,7 +277,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
                     drafts[s].drafting = false;
                     continue;
                 }
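
This log fix is small but real: the old message printed the runner-up probability cur_p[1].p under a stale "2*%.3f" format that no longer matched the condition being tested, whereas the new message prints the p_accept threshold that the top-token probability is actually compared against.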
@@ -337,16 +337,14 @@ int main(int argc, char ** argv) {
 
                 llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
-                // no need to evaluate the last drafted token, since we won't use the result
-                if (batch_tgt.n_tokens > n_draft) {
-                    drafts[s].drafting = false;
-                    continue;
-                }
-
                 // add the token to the batch for batched decoding with the draft model
                 drafts[s].i_batch_dft = batch_dft.n_tokens;
 
                 llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+
+                if (batch_tgt.n_tokens > n_draft) {
+                    drafts[s].drafting = false;
+                }
             }
         }
 
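The reordering above looks like the core fix of this commit: previously, once batch_tgt exceeded the n_draft budget, the early continue skipped adding the token to batch_dft entirely, so the last drafted token was never evaluated by the draft model. Now the token is always added to the draft batch and only further drafting is disabled, which also removes the need for the ++n_drafted compensation deleted in the next hunk.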

@@ -365,11 +363,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // account for the last drafted token that we didn't evaluate
-        if (batch_tgt.n_tokens > n_draft) {
-            ++n_drafted;
-        }
-
         // evaluate the target model on the drafted tokens
         {
             llama_kv_cache_seq_keep(ctx_tgt, 0);
