Skip to content

Commit 611f679

Browse files
ggerganov authored and arthw committed
speculative : refactor and add a simpler example (ggml-org#10362)
* speculative : refactor and add a simpler example (ggml-ci)
* speculative : clean-up and add comments and TODOs [no ci]
* speculative : manage context in common_speculative (ggml-ci)
* speculative : simplify (ggml-ci)
* speculative : simplify (cont) (ggml-ci)
* speculative : add --draft-min CLI arg
* speculative : minor fixup
* make : build fixes
* speculative : do not redraft previous drafts (ggml-ci)
* speculative : fix the draft sampling (ggml-ci)
* speculative : fix compile warning
* common : refactor args (ggml-ci)
* common : change defaults [no ci]
* common : final touches (ggml-ci)
1 parent 02dc28e commit 611f679

28 files changed

+1028
-326
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,7 @@ OBJ_COMMON = \
966966
$(DIR_COMMON)/console.o \
967967
$(DIR_COMMON)/ngram-cache.o \
968968
$(DIR_COMMON)/sampling.o \
969+
$(DIR_COMMON)/speculative.o \
969970
$(DIR_COMMON)/build-info.o \
970971
$(DIR_COMMON)/json-schema-to-grammar.o
971972

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
6666
ngram-cache.h
6767
sampling.cpp
6868
sampling.h
69+
speculative.cpp
70+
speculative.h
6971
)
7072

7173
if (BUILD_SHARED_LIBS)

common/arg.cpp

Lines changed: 231 additions & 207 deletions
Large diffs are not rendered by default.

common/common.cpp

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
536536
[](const unsigned char c) { return !std::isprint(c); }),
537537
detokenized.end());
538538

539-
buf << "\n" << std::to_string(i)
540-
<< ":token '" << detokenized << "'"
541-
<< ":pos " << std::to_string(batch.pos[i])
542-
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
543-
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
544-
<< ":logits " << std::to_string(batch.logits[i]);
539+
buf << "\n" << std::to_string(i)
540+
<< ", token '" << detokenized << "'"
541+
<< ", pos " << std::to_string(batch.pos[i])
542+
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
543+
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
544+
<< ", logits " << std::to_string(batch.logits[i]);
545545
}
546546

547547
buf << " ]";
@@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
925925
common_lora_adapters_apply(lctx, iparams.lora_adapters);
926926
}
927927

928-
if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
928+
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
929929
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
930-
params.sparams.ignore_eos = false;
930+
params.sampling.ignore_eos = false;
931931
}
932932

933933
if (params.warmup) {
@@ -1490,6 +1490,66 @@ void common_batch_add(
14901490
batch.n_tokens++;
14911491
}
14921492

1493+
//
1494+
// Token utils
1495+
//
1496+
1497+
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
1498+
size_t i;
1499+
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
1500+
1501+
return i;
1502+
}
1503+
1504+
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
1505+
// check for empty sequences
1506+
if (a.empty() || b.empty()) {
1507+
return 0;
1508+
}
1509+
1510+
// get the lengths of the input sequences
1511+
size_t a_len = a.size();
1512+
size_t b_len = b.size();
1513+
1514+
// initialize the maximum length of the longest common subsequence (LCS)
1515+
size_t max_length = 0;
1516+
1517+
// use two rows instead of a 2D matrix to optimize space
1518+
std::vector<size_t> prev_row(b_len + 1, 0);
1519+
std::vector<size_t> curr_row(b_len + 1, 0);
1520+
1521+
// iterate through the elements of a
1522+
for (size_t i = 1; i <= a_len; i++) {
1523+
// iterate through the elements of b
1524+
for (size_t j = 1; j <= b_len; j++) {
1525+
// if elements at the current positions match
1526+
if (a[i - 1] == b[j - 1]) {
1527+
// if it's the first element of either sequence, set LCS length to 1
1528+
if (i == 1 || j == 1) {
1529+
curr_row[j] = 1;
1530+
} else {
1531+
// increment LCS length by 1 compared to the previous element
1532+
curr_row[j] = prev_row[j - 1] + 1;
1533+
}
1534+
1535+
// update max_length if necessary
1536+
if (curr_row[j] > max_length) {
1537+
max_length = curr_row[j];
1538+
}
1539+
} else {
1540+
// reset LCS length if elements don't match
1541+
curr_row[j] = 0;
1542+
}
1543+
}
1544+
1545+
// update the previous row for the next iteration
1546+
prev_row = curr_row;
1547+
}
1548+
1549+
// return the maximum length of the LCS
1550+
return max_length;
1551+
}
1552+
14931553
//
14941554
// Vocab utils
14951555
//

common/common.h

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
3333
struct llama_lora_adapter * adapter;
3434
};
3535

36+
using llama_tokens = std::vector<llama_token>;
37+
3638
// build info
3739
extern int LLAMA_BUILD_NUMBER;
3840
extern char const * LLAMA_COMMIT;
@@ -101,8 +103,8 @@ enum dimre_method {
101103
DIMRE_METHOD_MEAN,
102104
};
103105

104-
// sampler parameters
105-
struct common_sampler_params {
106+
// sampling parameters
107+
struct common_params_sampling {
106108
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
107109

108110
int32_t n_prev = 64; // number of previous tokens to remember
@@ -153,19 +155,30 @@ struct common_sampler_params {
153155
std::string print() const;
154156
};
155157

158+
struct common_params_speculative {
159+
int32_t n_ctx = 0; // draft context size
160+
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
161+
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
162+
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
163+
float p_split = 0.1f; // speculative decoding split probability
164+
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
165+
166+
struct cpu_params cpuparams;
167+
struct cpu_params cpuparams_batch;
168+
169+
std::string model = ""; // draft model for speculative decoding // NOLINT
170+
};
171+
156172
struct common_params {
157173
int32_t n_predict = -1; // new tokens to predict
158174
int32_t n_ctx = 4096; // context size
159175
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
160176
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
161177
int32_t n_keep = 0; // number of tokens to keep from initial prompt
162-
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
163178
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
164179
int32_t n_parallel = 1; // number of parallel sequences to decode
165180
int32_t n_sequences = 1; // number of sequences to decode
166-
float p_split = 0.1f; // speculative decoding split probability
167181
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
168-
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
169182
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
170183
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
171184
int32_t grp_attn_n = 1; // group-attention factor
@@ -182,8 +195,6 @@ struct common_params {
182195

183196
struct cpu_params cpuparams;
184197
struct cpu_params cpuparams_batch;
185-
struct cpu_params draft_cpuparams;
186-
struct cpu_params draft_cpuparams_batch;
187198

188199
ggml_backend_sched_eval_callback cb_eval = nullptr;
189200
void * cb_eval_user_data = nullptr;
@@ -195,10 +206,10 @@ struct common_params {
195206
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
196207
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
197208

198-
struct common_sampler_params sparams;
209+
struct common_params_sampling sampling;
210+
struct common_params_speculative speculative;
199211

200212
std::string model = ""; // model path // NOLINT
201-
std::string model_draft = ""; // draft model for speculative decoding // NOLINT
202213
std::string model_alias = "unknown"; // model alias // NOLINT
203214
std::string model_url = ""; // model url to download // NOLINT
204215
std::string hf_token = ""; // HF token // NOLINT
@@ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
461472
// clear LoRA adapters from context, then apply new list of adapters
462473
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
463474

475+
//
464476
// Batch utils
477+
//
465478

466479
void common_batch_clear(struct llama_batch & batch);
467480

@@ -472,6 +485,16 @@ void common_batch_add(
472485
const std::vector<llama_seq_id> & seq_ids,
473486
bool logits);
474487

488+
//
489+
// Token utils
490+
//
491+
492+
// longest common prefix
493+
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
494+
495+
// longest common subsequence
496+
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
497+
475498
//
476499
// Vocab utils
477500
//

common/sampling.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct ring_buffer {
9999
};
100100

101101
struct common_sampler {
102-
common_sampler_params params;
102+
common_params_sampling params;
103103

104104
struct llama_sampler * grmr;
105105
struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
125125
}
126126
};
127127

128-
std::string common_sampler_params::print() const {
128+
std::string common_params_sampling::print() const {
129129
char result[1024];
130130

131131
snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
141141
return std::string(result);
142142
}
143143

144-
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
144+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
145145
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
146146

147147
lparams.no_perf = params.no_perf;
@@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
320320
return cur_p.data[cur_p.selected].id;
321321
}
322322

323+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
324+
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
325+
326+
std::vector<llama_token> result;
327+
result.reserve(idxs.size());
328+
329+
size_t i = 0;
330+
for (; i < draft.size(); i++) {
331+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
332+
333+
common_sampler_accept(gsmpl, id, true);
334+
335+
result.push_back(id);
336+
337+
if (draft[i] != id) {
338+
break;
339+
}
340+
}
341+
342+
if (i == draft.size()) {
343+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
344+
345+
common_sampler_accept(gsmpl, id, true);
346+
347+
result.push_back(id);
348+
}
349+
350+
return result;
351+
}
352+
353+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
354+
std::vector<int> idxs(draft.size() + 1);
355+
for (size_t i = 0; i < idxs.size(); ++i) {
356+
idxs[i] = i;
357+
}
358+
359+
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
360+
}
361+
323362
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
324363
return llama_sampler_get_seed(gsmpl->chain);
325364
}

common/sampling.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct common_sampler;
3636

3737
// llama_sampler API overloads
3838

39-
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
39+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
4040

4141
void common_sampler_free(struct common_sampler * gsmpl);
4242

@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
6060
//
6161
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
6262

63+
// generalized version of common_sampler_sample
64+
//
65+
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
66+
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
67+
//
68+
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
69+
//
70+
// is equivalent to
71+
//
72+
// common_sampler_sample(gsmpl, ctx, idx);
73+
// common_sampler_accept(gsmpl, token, true);
74+
//
75+
// requires: idxs.size() == draft.size() + 1
76+
//
77+
// returns at least 1 token, up to idxs.size()
78+
//
79+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
80+
81+
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
82+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
83+
6384
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
6485

6586
// helpers

0 commit comments

Comments (0)