Skip to content

Commit fb1faaf

Browse files
authored
Support multiple prompts in the runner
Differential Revision: D72276104 Pull Request resolved: #9817
1 parent 2f408dd commit fb1faaf

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,22 @@ DEFINE_string(
5959
"SmartMask");
6060
// How many times main()'s outer generation loop repeats the full set of
// prompts (default: 1). The help string below is user-visible gflags output.
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
6161

62+
/// Collect every prompt passed on the command line, in order.
///
/// Must run BEFORE gflags::ParseCommandLineFlags, which rewrites argv.
/// Example usage:
///   --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
///
/// Recognizes both flag spellings gflags accepts:
///   --prompt value    (two tokens)
///   --prompt=value    (single token)
///
/// @param argc argument count as given to main().
/// @param argv argument vector as given to main().
/// @return all prompt values found, possibly empty.
std::vector<std::string> CollectPrompts(int argc, char** argv) {
  std::vector<std::string> prompts;
  for (int i = 1; i < argc; i++) {
    std::string arg(argv[i]);
    if (arg == "--prompt" && i + 1 < argc) {
      prompts.push_back(argv[i + 1]);
      i++; // Skip the value token; it was consumed above.
    } else if (arg.rfind("--prompt=", 0) == 0) {
      // gflags also accepts the single-token --prompt=value form;
      // without this branch such prompts were silently dropped.
      prompts.push_back(arg.substr(std::string("--prompt=").size()));
    }
  }
  return prompts;
}
74+
6275
int main(int argc, char** argv) {
76+
std::vector<std::string> prompts = CollectPrompts(argc, argv);
6377
gflags::ParseCommandLineFlags(&argc, &argv, true);
64-
6578
// create llama runner
6679
example::Runner runner(
6780
{FLAGS_model_path},
@@ -83,11 +96,10 @@ int main(int argc, char** argv) {
8396
};
8497
// generate tokens & store inference output
8598
for (int i = 0; i < FLAGS_num_iters; i++) {
86-
runner.generate(
87-
FLAGS_seq_len,
88-
FLAGS_prompt.c_str(),
89-
FLAGS_system_prompt.c_str(),
90-
callback);
99+
for (const auto& prompt : prompts) {
100+
runner.generate(
101+
FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
102+
}
91103
}
92104
fout.write(buf.data(), buf.size());
93105
fout.close();

0 commit comments

Comments
 (0)