diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index d7f0d85156c..f23cf2ec44a 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -59,9 +59,22 @@ DEFINE_string(
     "SmartMask");
 DEFINE_int32(num_iters, 1, "total num of iterations to run.");
 
+std::vector<std::string> CollectPrompts(int argc, char** argv) {
+  // Collect all prompts from the command line, example usage:
+  // --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
+  std::vector<std::string> prompts;
+  for (int i = 1; i < argc; i++) {
+    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
+      prompts.push_back(argv[i + 1]);
+      i++; // Skip the next argument
+    }
+  }
+  return prompts;
+}
+
 int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-
   // create llama runner
   example::Runner runner(
       {FLAGS_model_path},
@@ -83,11 +96,10 @@ int main(int argc, char** argv) {
   };
   // generate tokens & store inference output
   for (int i = 0; i < FLAGS_num_iters; i++) {
-    runner.generate(
-        FLAGS_seq_len,
-        FLAGS_prompt.c_str(),
-        FLAGS_system_prompt.c_str(),
-        callback);
+    for (const auto& prompt : prompts) {
+      runner.generate(
+          FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
+    }
   }
   fout.write(buf.data(), buf.size());
   fout.close();
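
Note on the design: CollectPrompts scans raw argv before gflags::ParseCommandLineFlags runs, because a scalar gflags flag such as FLAGS_prompt retains only the last value when the flag is repeated, and parsing with remove_flags=true strips the flags from argv afterward. Below is a minimal standalone sketch of that collection logic so it can be exercised in isolation; the hand-built argv values here are illustrative only and not part of the patch.

#include <iostream>
#include <string>
#include <vector>

// Same scan as CollectPrompts in the patch: walk raw argv and gather every
// "--prompt <value>" pair in order of appearance.
std::vector<std::string> CollectPrompts(int argc, char** argv) {
  std::vector<std::string> prompts;
  for (int i = 1; i < argc; i++) {
    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
      prompts.push_back(argv[i + 1]);
      i++; // Skip the value just consumed.
    }
  }
  return prompts;
}

int main() {
  // Hand-built argument vector purely for illustration; a real run would pass
  // these flags on the qnn_llama_runner command line.
  const char* args[] = {
      "qnn_llama_runner", "--prompt", "prompt1", "--prompt", "prompt2"};
  auto prompts = CollectPrompts(
      static_cast<int>(sizeof(args) / sizeof(args[0])),
      const_cast<char**>(args));
  for (const auto& p : prompts) {
    std::cout << p << "\n"; // Prints "prompt1", then "prompt2".
  }
  return 0;
}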