
Commit 8a19c85

cccclai authored and facebook-github-bot committed
Support multiple prompts in the runner (#9817)
Summary:
Pull Request resolved: #9817

As preparation for multiturn conversations, we can provide multiple prompts and execute them in sequence. Example command:

```
./qnn_llama3_2_runner --model_path hybrid_llama_qnn.pte --tokenizer_path tiktokenizer.bin --eval_mode 1 --prompt "Once upon a time" --prompt "girl named Lily." --prompt "her toys and her favorite toy was a big," --kv_updater "ShiftPointer" --logits_scale 0.1 --output_path output.txt --num_iters 1
```

It would be hard to use any single character as a delimiter, so we use `--prompt` to explicitly mark each prompt and collect them together.

Reviewed By: kirklandsign

Differential Revision: D72276104
1 parent 7d36e41 · commit 8a19c85

File tree

1 file changed: +91 −79 lines changed


examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 91 additions & 79 deletions
```diff
@@ -14,82 +14,94 @@
  *
  */
 
-#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
-#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
-#include <executorch/runtime/platform/log.h>
-#include <gflags/gflags.h>
-#include <fstream>
-#include <vector>
-
-DEFINE_string(
-    model_path,
-    "kv_llama_qnn.pte",
-    "Model serialized in flatbuffer format.");
-DEFINE_string(
-    output_path,
-    "outputs.txt",
-    "Executorch inference data output path.");
-DEFINE_string(
-    performance_output_path,
-    "inference_speed.txt",
-    "Records inference speed. For CI purpose.");
-DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
-DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
-DEFINE_string(
-    system_prompt,
-    "",
-    "Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None");
-DEFINE_double(
-    temperature,
-    0.0f,
-    "Temperature; Default is 0.0f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
-DEFINE_int32(
-    seq_len,
-    128,
-    "Total number of tokens to generate (prompt + output).");
-DEFINE_int32(
-    eval_mode,
-    1,
-    "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
-DEFINE_double(logits_scale, 0.0, "Logits scale");
-DEFINE_int32(logits_offset, 0, "Logits offset");
-DEFINE_string(
-    kv_updater,
-    "How to update kv cache. Choose between SmartMask and ShiftPointer",
-    "SmartMask");
-DEFINE_int32(num_iters, 1, "total num of iterations to run.");
-
-int main(int argc, char** argv) {
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-
-  // create llama runner
-  example::Runner runner(
-      {FLAGS_model_path},
-      FLAGS_tokenizer_path.c_str(),
-      FLAGS_performance_output_path.c_str(),
-      FLAGS_logits_scale,
-      FLAGS_logits_offset,
-      FLAGS_temperature,
-      FLAGS_eval_mode,
-      FLAGS_kv_updater,
-      FLAGS_num_iters);
-  std::vector<char> buf;
-  buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
-  std::ofstream fout(FLAGS_output_path.c_str());
-  auto callback = [&](const std::string& piece) {
-    for (const char c : piece) {
-      buf.push_back(c);
-    }
-  };
-  // generate tokens & store inference output
-  for (int i = 0; i < FLAGS_num_iters; i++) {
-    runner.generate(
-        FLAGS_seq_len,
-        FLAGS_prompt.c_str(),
-        FLAGS_system_prompt.c_str(),
-        callback);
-  }
-  fout.write(buf.data(), buf.size());
-  fout.close();
-  return 0;
-}
+#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/runtime/platform/log.h>
+#include <gflags/gflags.h>
+#include <fstream>
+#include <vector>
+
+DEFINE_string(
+    model_path,
+    "kv_llama_qnn.pte",
+    "Model serialized in flatbuffer format.");
+DEFINE_string(
+    output_path,
+    "outputs.txt",
+    "Executorch inference data output path.");
+DEFINE_string(
+    performance_output_path,
+    "inference_speed.txt",
+    "Records inference speed. For CI purpose.");
+DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
+DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
+DEFINE_string(
+    system_prompt,
+    "",
+    "Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None");
+DEFINE_double(
+    temperature,
+    0.0f,
+    "Temperature; Default is 0.0f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
+DEFINE_int32(
+    seq_len,
+    128,
+    "Total number of tokens to generate (prompt + output).");
+DEFINE_int32(
+    eval_mode,
+    1,
+    "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
+DEFINE_double(logits_scale, 0.0, "Logits scale");
+DEFINE_int32(logits_offset, 0, "Logits offset");
+DEFINE_string(
+    kv_updater,
+    "How to update kv cache. Choose between SmartMask and ShiftPointer",
+    "SmartMask");
+DEFINE_int32(num_iters, 1, "total num of iterations to run.");
+
+std::vector<std::string> CollectPrompts(int argc, char** argv) {
+  // Collect all prompts from command line, example usage:
+  // --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
+  std::vector<std::string> prompts;
+  for (int i = 1; i < argc; i++) {
+    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
+      prompts.push_back(argv[i + 1]);
+      i++; // Skip the next argument
+    }
+  }
+  return prompts;
+}
+
+int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  // create llama runner
+  example::Runner runner(
+      {FLAGS_model_path},
+      FLAGS_tokenizer_path.c_str(),
+      FLAGS_performance_output_path.c_str(),
+      FLAGS_logits_scale,
+      FLAGS_logits_offset,
+      FLAGS_temperature,
+      FLAGS_eval_mode,
+      FLAGS_kv_updater,
+      FLAGS_num_iters);
+  std::vector<char> buf;
+  buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
+  std::ofstream fout(FLAGS_output_path.c_str());
+  auto callback = [&](const std::string& piece) {
+    for (const char c : piece) {
+      buf.push_back(c);
+    }
+  };
+  // generate tokens & store inference output
+  for (int i = 0; i < FLAGS_num_iters; i++) {
+    for (const auto& prompt : prompts) {
+      runner.generate(
+          FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
+    }
+  }
+  fout.write(buf.data(), buf.size());
+  fout.close();
+  return 0;
+}
```
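For reference, here is a minimal standalone sketch (not part of the commit) of the same pre-parse scan that the new `CollectPrompts` helper performs. The prompts are gathered directly from `argv` before `gflags::ParseCommandLineFlags` runs, since gflags keeps only the last value of a repeated string flag and, with `remove_flags=true`, strips the flags from `argv`. The binary name in the comment is purely illustrative.

```cpp
// Sketch: collect every value passed via a repeated "--prompt" flag by
// scanning argv before gflags consumes it. Assumes the value always follows
// as a separate argument ("--prompt <value>"), matching the runner's usage;
// the "--prompt=<value>" form is not handled here.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> CollectPrompts(int argc, char** argv) {
  std::vector<std::string> prompts;
  for (int i = 1; i < argc; i++) {
    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
      prompts.push_back(argv[i + 1]);
      i++; // skip the value we just consumed
    }
  }
  return prompts;
}

int main(int argc, char** argv) {
  // e.g. ./collect --prompt "Once upon a time" --prompt "girl named Lily."
  for (const auto& p : CollectPrompts(argc, argv)) {
    std::cout << p << "\n";
  }
  return 0;
}
```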
