Skip to content

common.cpp: add --enable-special-out and --disable-special-out for overriding special token output #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.interactive = true;
return true;
}
// -eso / --enable-special-out: force special token rendering on.
// Mutually exclusive with -dso / --disable-special-out: supplying both
// marks the parameter invalid (checked in whichever handler runs second).
if (arg == "-eso" || arg == "--enable-special-out") {
    params.enable_special_token_rendering = true;
    if (params.disable_special_token_rendering) {
        invalid_param = true; // conflicting override flags
    }
    return true;
}
// -dso / --disable-special-out: force special token rendering off.
if (arg == "-dso" || arg == "--disable-special-out") {
    params.disable_special_token_rendering = true;
    if (params.enable_special_token_rendering) {
        invalid_param = true; // conflicting override flags
    }
    return true;
}
if (arg == "--interactive-specials") {
params.interactive_specials = true;
return true;
Expand Down Expand Up @@ -1432,6 +1448,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -h, --help show this help message and exit\n");
printf(" --version show version and build info\n");
printf(" -i, --interactive run in interactive mode\n");
// Comma after the short option matches the style of every other entry
// (e.g. "-i, --interactive", "-cnv, --conversation").
printf(" -eso, --enable-special-out enable special token printing (overrides default behaviour)\n");
printf(" -dso, --disable-special-out disable special token printing (overrides default behaviour)\n");
printf(" --interactive-specials allow special tokens in user text, in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n");
Expand Down Expand Up @@ -1493,8 +1511,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" modifies the likelihood of token appearing in the completion,\n");
printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
printf(" --grammar-file FNAME file to read grammar from\n");
printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir) (special token disabled by default)\n");
printf(" --grammar-file FNAME file to read grammar from (special token disabled by default)\n");
printf(" -j SCHEMA, --json-schema SCHEMA\n");
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ struct gpt_params {
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool enable_special_token_rendering = false; // force special token rendering on regardless of default (useful for debugging); mutually exclusive with disable_special_token_rendering — setting both is rejected during arg parsing
bool disable_special_token_rendering = false; // force special token rendering off regardless of default (useful for scripting); mutually exclusive with enable_special_token_rendering — setting both is rejected during arg parsing
bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
Expand Down
12 changes: 9 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,13 @@ int main(int argc, char ** argv) {
exit(1);
}

bool should_show_special_tokens = sparams.grammar.empty();
const bool special_token_render_override = params.enable_special_token_rendering || params.disable_special_token_rendering;
bool special_token_render = sparams.grammar.empty();
if (params.enable_special_token_rendering) {
special_token_render = true;
} else if (params.disable_special_token_rendering) {
special_token_render = false;
}

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
Expand Down Expand Up @@ -742,7 +748,7 @@ int main(int argc, char ** argv) {
// display text
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation && should_show_special_tokens);
const std::string token_str = llama_token_to_piece(ctx, id, special_token_render_override ? special_token_render : !params.conversation && special_token_render);
printf("%s", token_str.c_str());

if (embd.size() > 1) {
Expand Down Expand Up @@ -908,7 +914,7 @@ int main(int argc, char ** argv) {
for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_ss << llama_token_to_piece(ctx, token, should_show_special_tokens);
output_ss << llama_token_to_piece(ctx, token, special_token_render);
}

n_remain -= line_inp.size();
Expand Down