@@ -8,6 +8,9 @@
 #include <string>
 #include <vector>
 
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct seq_draft {
     bool active   = false;
     bool drafting = false;
@@ -64,6 +67,33 @@ int main(int argc, char ** argv) {
     params.n_gpu_layers = params.n_gpu_layers_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
+            fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return 1;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
+                fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
+                        llama_token_to_piece(ctx_tgt, i).c_str(),
+                        llama_token_to_piece(ctx_dft, i).c_str());
+                return 1;
+            }
+        }
+    }
+
     // tokenize the prompt
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
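Aside on the new check (not part of the commit itself): the added block is self-contained enough to factor into a helper. Below is a minimal sketch under that assumption; the name spec_vocab_compatible is hypothetical, while llama_n_vocab and llama_token_get_text are the same llama.h calls the diff uses. Note that std::strcmp formally requires <cstring>, which is worth including explicitly rather than relying on transitive includes.

#include <algorithm> // std::min
#include <cstdlib>   // std::abs
#include <cstring>   // std::strcmp

#include "llama.h"

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

// Hypothetical helper mirroring the checks in the diff: returns true when
// the draft vocab is usable for speculation against the target vocab, i.e.
// the sizes differ by at most SPEC_VOCAB_MAX_SIZE_DIFFERENCE and every
// shared token past the first few has identical text in both models.
static bool spec_vocab_compatible(const llama_model * model_tgt, const llama_model * model_dft) {
    const int n_vocab_tgt = llama_n_vocab(model_tgt);
    const int n_vocab_dft = llama_n_vocab(model_dft);

    if (std::abs(n_vocab_tgt - n_vocab_dft) > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
        return false;
    }

    for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
        if (std::strcmp(llama_token_get_text(model_tgt, i),
                        llama_token_get_text(model_dft, i)) != 0) {
            return false;
        }
    }

    return true;
}

Starting the comparison at token ID 5 (SPEC_VOCAB_CHECK_START_TOKEN_ID) presumably skips the first few IDs, which are typically special tokens (BOS, EOS, padding) whose text can legitimately differ between otherwise compatible models; likewise, the size check tolerates small vocab-size differences such as a padded embedding matrix.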
@@ -227,6 +257,7 @@ int main(int argc, char ** argv) {
         llama_batch_add  (batch_dft, id, n_past_dft, { 0 }, true);
 
         llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+        // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
         llama_decode         (ctx_dft, batch_dft);
 
         ++n_past_dft;
@@ -370,7 +401,7 @@ int main(int argc, char ** argv) {
            llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
        }
 
-       //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
+       // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
        llama_decode(ctx_tgt, batch_tgt);
        ++n_past_tgt;
    }
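A note on the two LOG lines: both now append .c_str(). Judging from the fix, LOG_BATCH_TOSTR_PRETTY produces a std::string while LOG forwards its arguments printf-style, and passing a std::string object where %s expects a const char * is undefined behavior. A minimal illustration of the pattern, independent of the llama.cpp logging macros:

#include <cstdio>
#include <string>

int main() {
    // Stand-in for the string LOG_BATCH_TOSTR_PRETTY would build.
    const std::string batch_desc = "n_tokens = 1, pos = 42";

    // printf("dft batch: %s\n", batch_desc);       // undefined behavior: %s expects const char *
    printf("dft batch: %s\n", batch_desc.c_str());  // correct

    return 0;
}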