@@ -59,9 +59,22 @@ DEFINE_string(
     "SmartMask");
 DEFINE_int32(num_iters, 1, "total num of iterations to run.");
 
+std::vector<std::string> CollectPrompts(int argc, char** argv) {
+  // Collect all prompts from command line, example usage:
+  // --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
+  std::vector<std::string> prompts;
+  for (int i = 1; i < argc; i++) {
+    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
+      prompts.push_back(argv[i + 1]);
+      i++; // Skip the next argument
+    }
+  }
+  return prompts;
+}
+
 int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-
   // create llama runner
   example::Runner runner(
       {FLAGS_model_path},
@@ -83,11 +96,10 @@ int main(int argc, char** argv) {
   };
   // generate tokens & store inference output
   for (int i = 0; i < FLAGS_num_iters; i++) {
-    runner.generate(
-        FLAGS_seq_len,
-        FLAGS_prompt.c_str(),
-        FLAGS_system_prompt.c_str(),
-        callback);
+    for (const auto& prompt : prompts) {
+      runner.generate(
+          FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
+    }
   }
   fout.write(buf.data(), buf.size());
   fout.close();
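
For context, a minimal standalone sketch (not part of the patch) of how the added CollectPrompts helper gathers repeated --prompt arguments before gflags parsing; the test main and sample argv values below are hypothetical:

#include <iostream>
#include <string>
#include <vector>

// Same parsing logic as the CollectPrompts helper added in the diff above.
std::vector<std::string> CollectPrompts(int argc, char** argv) {
  std::vector<std::string> prompts;
  for (int i = 1; i < argc; i++) {
    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
      prompts.push_back(argv[i + 1]);
      i++; // Skip the prompt value we just consumed
    }
  }
  return prompts;
}

int main() {
  // Hypothetical invocation: runner --prompt "prompt1" --prompt "prompt2"
  const char* args[] = {"runner", "--prompt", "prompt1", "--prompt", "prompt2"};
  for (const auto& p : CollectPrompts(5, const_cast<char**>(args))) {
    std::cout << p << "\n"; // prints "prompt1" then "prompt2"
  }
  return 0;
}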