speculative : refactor and add a simpler example #10362

Merged · 14 commits · Nov 25, 2024

1 change: 1 addition & 0 deletions Makefile
@@ -966,6 +966,7 @@ OBJ_COMMON = \
$(DIR_COMMON)/console.o \
$(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \
+    $(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/json-schema-to-grammar.o

2 changes: 2 additions & 0 deletions common/CMakeLists.txt
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
ngram-cache.h
sampling.cpp
sampling.h
+    speculative.cpp
+    speculative.h
)

if (BUILD_SHARED_LIBS)
438 changes: 231 additions & 207 deletions common/arg.cpp

Large diffs are not rendered by default.

76 changes: 68 additions & 8 deletions common/common.cpp
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());

-        buf << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+        buf << "\n" << std::to_string(i)
+            << ", token '" << detokenized << "'"
+            << ", pos " << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id " << std::to_string(batch.seq_id[i][0])
+            << ", logits " << std::to_string(batch.logits[i]);
}

buf << " ]";
@@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
}

-    if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sparams.ignore_eos = false;
+        params.sampling.ignore_eos = false;
}

if (params.warmup) {
@@ -1490,6 +1490,66 @@ void common_batch_add(
batch.n_tokens++;
}

+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequence, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}

//
// Vocab utils
//
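For intuition about the new common_lcp/common_lcs helpers, here is a minimal usage sketch (not part of the diff; it assumes compilation against the common library above):

#include "common.h"

#include <cassert>

int main() {
    // toy token IDs: the shared prefix is {7, 8}; the longest contiguous
    // shared run is {2, 3} (common_lcs resets on mismatch, so it measures
    // contiguous runs rather than general subsequences)
    const llama_tokens a = {7, 8, 2, 3, 4};
    const llama_tokens b = {7, 8, 9, 2, 3};

    assert(common_lcp(a, b) == 2);
    assert(common_lcs(a, b) == 2);

    return 0;
}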
41 changes: 32 additions & 9 deletions common/common.h
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
};

+using llama_tokens = std::vector<llama_token>;

// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
@@ -101,8 +103,8 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};

-// sampler parameters
-struct common_sampler_params {
+// sampling parameters
+struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

int32_t n_prev = 64; // number of previous tokens to remember
@@ -153,19 +155,30 @@ struct common_sampler_params {
std::string print() const;
};

+struct common_params_speculative {
+    int32_t n_ctx = 0;         // draft context size
+    int32_t n_max = 16;        // maximum number of tokens to draft during speculative decoding
+    int32_t n_min = 5;         // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split = 0.1f;    // speculative decoding split probability
+    float   p_min = 0.9f;      // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding // NOLINT
+};

struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
-    float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
@@ -182,8 +195,6 @@ struct common_params {

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -195,10 +206,10 @@ struct common_params {
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

-    struct common_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;

std::string model = ""; // model path // NOLINT
-    std::string model_draft = ""; // draft model for speculative decoding // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
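As a rough illustration of the refactor (hypothetical values, not code from the PR): the draft-model settings that used to live directly in common_params (n_draft, p_split, n_gpu_layers_draft, model_draft, draft_cpuparams) are now reached through the nested speculative struct, and sparams is renamed to sampling:

#include "common.h"

static void configure(common_params & params) {
    params.speculative.model        = "models/draft.gguf"; // was params.model_draft
    params.speculative.n_max        = 16;                  // was params.n_draft
    params.speculative.p_min        = 0.9f;
    params.speculative.n_gpu_layers = 99;                  // was params.n_gpu_layers_draft

    params.sampling.ignore_eos = false;                    // was params.sparams.ignore_eos
}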
@@ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);

+//
// Batch utils
+//

void common_batch_clear(struct llama_batch & batch);

@@ -472,6 +485,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);

+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);

//
// Vocab utils
//
45 changes: 42 additions & 3 deletions common/sampling.cpp
@@ -99,7 +99,7 @@ struct ring_buffer {
};

struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;

struct llama_sampler * grmr;
struct llama_sampler * chain;
@@ -125,7 +125,7 @@
}
};

-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
char result[1024];

snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
return std::string(result);
}

-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

lparams.no_perf = params.no_perf;
@@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id;
}

+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}
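To make the acceptance rule in common_sampler_sample_and_accept_n concrete, here is a hypothetical trace (a sketch, not code from the PR): with draft = {5, 6, 7}, suppose the target sampler yields 5, then 6, then 9 at the corresponding logit indices. The first two draft tokens match and are accepted; the third disagrees, so the loop stops and the target's own token is kept:

static void verify_example(common_sampler * gsmpl, llama_context * ctx) {
    // assumes llama_decode() has already produced logits at indices 0..3
    const llama_tokens draft = {5, 6, 7};

    const std::vector<llama_token> out = common_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    // with the trace above: out == {5, 6, 9}, i.e. two accepted draft tokens
    // plus the target's correction; out.size() - 1 is always the number of
    // draft tokens that were accepted
    const size_t n_accepted = out.size() - 1;
    (void) n_accepted;
}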
23 changes: 22 additions & 1 deletion common/sampling.h
@@ -36,7 +36,7 @@ struct common_sampler;

// llama_sampler API overloads

-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);

void common_sampler_free(struct common_sampler * gsmpl);

@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

+// generalized version of common_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
+//
+// is equivalent to
+//
+//      common_sampler_sample(gsmpl, ctx, idx);
+//      common_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

// helpers
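The explicit-idxs overload exists for batches where a sequence's logits do not sit at indices 0..draft.size(), e.g. when several sequences are decoded in one llama_batch. A minimal sketch, with hypothetical index values:

static void verify_with_idxs(common_sampler * gsmpl, llama_context * ctx) {
    // logit indices of this sequence's verification positions in the last
    // batch; must satisfy idxs.size() == draft.size() + 1, the extra slot
    // sampling the token that follows a fully accepted draft
    const std::vector<int> idxs  = {4, 5, 6, 7};
    const llama_tokens     draft = {101, 102, 103};

    const std::vector<llama_token> out = common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
    (void) out;
}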