Skip to content

Commit 56ba00b

Browse files
committed
sampling : hide prev behind API and apply #3661
ggml-ci
1 parent 7e2b5fb commit 56ba00b

File tree

9 files changed

+117
-103
lines changed

9 files changed

+117
-103
lines changed

common/sampling.cpp

+21-2
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,24 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
6666
dst->prev = src->prev;
6767
}
6868

69+
llama_token llama_sampling_last(llama_sampling_context * ctx) {
70+
return ctx->prev.back();
71+
}
72+
73+
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
74+
const int size = ctx_sampling->prev.size();
75+
76+
n = std::min(n, size);
77+
78+
std::string result;
79+
80+
for (int i = size - n; i < size; i++) {
81+
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
82+
}
83+
84+
return result;
85+
}
86+
6987
std::string llama_sampling_print(const llama_sampling_params & params) {
7088
char result[1024];
7189

@@ -193,11 +211,12 @@ llama_token llama_sampling_sample(
193211
void llama_sampling_accept(
194212
struct llama_sampling_context * ctx_sampling,
195213
struct llama_context * ctx_main,
196-
llama_token id) {
214+
llama_token id,
215+
bool apply_grammar) {
197216
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
198217
ctx_sampling->prev.push_back(id);
199218

200-
if (ctx_sampling->grammar != NULL) {
219+
if (ctx_sampling->grammar != NULL && apply_grammar) {
201220
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
202221
}
203222
}

common/sampling.h

+8-1
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,12 @@ void llama_sampling_reset(llama_sampling_context * ctx);
7070
// Copy the sampler context
7171
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
7272

73+
// Get the last sampled token
74+
llama_token llama_sampling_last(llama_sampling_context * ctx);
75+
76+
// Get a string representation of the last sampled tokens
77+
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
78+
7379
// Print sampling parameters into a string
7480
std::string llama_sampling_print(const llama_sampling_params & params);
7581

@@ -99,4 +105,5 @@ llama_token llama_sampling_sample(
99105
void llama_sampling_accept(
100106
struct llama_sampling_context * ctx_sampling,
101107
struct llama_context * ctx_main,
102-
llama_token id);
108+
llama_token id,
109+
bool apply_grammar);

examples/CMakeLists.txt

+14-13
Original file line number | Diff line number | Diff line change
@@ -12,25 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
1212

1313
if (EMSCRIPTEN)
1414
else()
15+
add_subdirectory(baby-llama)
16+
add_subdirectory(batched)
17+
add_subdirectory(batched-bench)
18+
add_subdirectory(beam-search)
19+
add_subdirectory(benchmark)
20+
add_subdirectory(convert-llama2c-to-ggml)
21+
add_subdirectory(embedding)
22+
add_subdirectory(finetune)
23+
add_subdirectory(infill)
24+
add_subdirectory(llama-bench)
25+
add_subdirectory(llava)
1526
add_subdirectory(main)
27+
add_subdirectory(parallel)
28+
add_subdirectory(perplexity)
1629
add_subdirectory(quantize)
1730
add_subdirectory(quantize-stats)
18-
add_subdirectory(perplexity)
19-
add_subdirectory(embedding)
2031
add_subdirectory(save-load-state)
21-
add_subdirectory(benchmark)
22-
add_subdirectory(baby-llama)
23-
add_subdirectory(train-text-from-scratch)
24-
add_subdirectory(finetune)
25-
add_subdirectory(convert-llama2c-to-ggml)
2632
add_subdirectory(simple)
27-
add_subdirectory(batched)
28-
add_subdirectory(batched-bench)
2933
add_subdirectory(speculative)
30-
add_subdirectory(parallel)
31-
add_subdirectory(llava)
32-
add_subdirectory(llama-bench)
33-
add_subdirectory(beam-search)
34+
add_subdirectory(train-text-from-scratch)
3435
if (LLAMA_METAL)
3536
add_subdirectory(metal)
3637
endif()

examples/infill/CMakeLists.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME)
44
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
55
target_compile_features(${TARGET} PRIVATE cxx_std_11)
66
if(TARGET BUILD_INFO)
7-
add_dependencies(${TARGET} BUILD_INFO)
7+
add_dependencies(${TARGET} BUILD_INFO)
88
endif()

examples/infill/infill.cpp

+11-8
Original file line number | Diff line number | Diff line change
@@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
523523

524524
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
525525

526-
llama_sampling_accept(ctx_sampling, ctx, id);
526+
llama_sampling_accept(ctx_sampling, ctx, id, true);
527527

528528
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
529529

@@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
541541
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
542542
while ((int) embd_inp.size() > n_consumed) {
543543
embd.push_back(embd_inp[n_consumed]);
544-
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
545-
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
544+
545+
// push the prompt in the sampling context in order to apply repetition penalties later
546+
// for the prompt, we don't apply grammar rules
547+
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
548+
546549
++n_consumed;
547550
if ((int) embd.size() >= params.n_batch) {
548551
break;
@@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
574577
if ((int) embd_inp.size() <= n_consumed) {
575578

576579
// deal with eot token in infill mode
577-
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
580+
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
578581
if(is_interacting && !params.interactive_first) {
579582
// print an eot token
580583
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
591594
buffer += line;
592595
} while (another_line);
593596
// check if we got an empty line, if so we use the old input
594-
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
597+
if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
595598
params.input_prefix = buffer;
596599
}
597600
buffer.clear();
@@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
601604
buffer += line;
602605
} while (another_line);
603606
// check if we got an empty line
604-
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
607+
if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
605608
params.input_suffix = buffer;
606609
}
607610
buffer.clear();
@@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
614617
process_escapes(params.input_suffix);
615618
}
616619
suff_rm_leading_spc = params.escape;
617-
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
620+
if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
618621
params.input_suffix.erase(0, 1);
619622
suff_rm_leading_spc = false;
620623
}
@@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
641644
is_interacting = false;
642645
}
643646
// deal with end of text token in interactive mode
644-
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
647+
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
645648
LOG("found EOS token\n");
646649

647650
if (params.interactive) {

examples/main/main.cpp

+8-13
Original file line number | Diff line number | Diff line change
@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
611611

612612
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
613613

614-
llama_sampling_accept(ctx_sampling, ctx, id);
614+
llama_sampling_accept(ctx_sampling, ctx, id, true);
615615

616616
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
617617

@@ -630,12 +630,9 @@ int main(int argc, char ** argv) {
630630
while ((int) embd_inp.size() > n_consumed) {
631631
embd.push_back(embd_inp[n_consumed]);
632632

633-
// GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
634-
// Most likely will remove this in the future to avoid exposing "prev"
635-
// Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
636-
// penalty will be applied only based on the tokens generated by the model.
637-
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
638-
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
633+
// push the prompt in the sampling context in order to apply repetition penalties later
634+
// for the prompt, we don't apply grammar rules
635+
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
639636

640637
++n_consumed;
641638
if ((int) embd.size() >= params.n_batch) {
@@ -666,12 +663,10 @@ int main(int argc, char ** argv) {
666663

667664
// if not currently processing queued inputs;
668665
if ((int) embd_inp.size() <= n_consumed) {
669-
// check for reverse prompt
666+
// check for reverse prompt in the last n_prev tokens
670667
if (!params.antiprompt.empty()) {
671-
std::string last_output;
672-
for (auto id : ctx_sampling->prev) {
673-
last_output += llama_token_to_piece(ctx, id);
674-
}
668+
const int n_prev = 32;
669+
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
675670

676671
is_antiprompt = false;
677672
// Check if each of the reverse prompts appears at the end of the output.
@@ -698,7 +693,7 @@ int main(int argc, char ** argv) {
698693
}
699694

700695
// deal with end of text token in interactive mode
701-
if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
696+
if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
702697
LOG("found EOS token\n");
703698

704699
if (params.interactive) {

examples/parallel/parallel.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
330330

331331
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
332332

333-
llama_sampling_accept(client.ctx_sampling, ctx, id);
333+
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
334334

335335
if (client.n_decoded == 1) {
336336
// start measuring generation time after the first token to make sure all concurrent clients

0 commit comments

Comments (0)