@@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
 }
 
 static std::string k_system = R"(
-Transcript of a dialog, where the User interacts with an Assistant.
+Transcript of a never ending dialog, where the User interacts with an Assistant.
 The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
 
 User: Hello, what is the temperature outside?
@@ -59,6 +59,9 @@ struct client {
 
     llama_token sampled;
 
+    int64_t t_start_prompt;
+    int64_t t_start_gen;
+
     int32_t n_prompt  = 0;
     int32_t n_decoded = 0;
     int32_t i_batch   = -1;
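
The two added fields record wall-clock timestamps in microseconds (via ggml_time_us()), so each client's prompt-processing and generation phases can be timed separately. A minimal sketch of the intended lifecycle, using the field names above (annotation, not part of the patch):

    // when a client is (re)assigned a sequence:
    client.t_start_prompt = ggml_time_us(); // prompt evaluation starts
    client.t_start_gen    = 0;              // sentinel: no token sampled yet

    // when the first token for this sequence is sampled (see the hunk at -188 below):
    if (client.t_start_gen == 0) {
        client.t_start_gen = ggml_time_us(); // prompt done, generation begins
    }
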
@@ -133,33 +136,47 @@ int main(int argc, char ** argv) {
 
         for (auto & client : clients) {
             if (client.seq_id == -1) {
-                client.seq_id = g_seq_id;
-                client.input = k_prompts[rand() % k_prompts.size()];
-                client.prompt = k_system + client.input + "\nAssistant:";
-                client.response = "";
-                std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
-
-                std::vector<llama_token> prompt_tokens;
-                prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
-
-                for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-                    batch_token.push_back(prompt_tokens[i]);
-                    batch_pos.push_back(i);
-                    batch_seq_id.push_back(client.seq_id);
-                    batch_clients.push_back(&client);
+                continue;
+            }
+
+            batch_token.push_back(client.sampled);
+            batch_pos.push_back(client.n_decoded);
+            batch_seq_id.push_back(client.seq_id);
+            batch_clients.push_back(&client);
+            client.n_decoded += 1;
+            client.i_batch = batch_token.size() - 1;
+        }
+
+        if (batch_token.empty()) {
+            // all sequences have ended - clear the entire KV cache
+            llama_kv_cache_rm_tokens(ctx, -1, -1);
+
+            for (auto & client : clients) {
+                if (client.seq_id == -1) {
+                    client.seq_id = g_seq_id;
+                    client.t_start_prompt = ggml_time_us();
+                    client.t_start_gen    = 0;
+
+                    client.input = k_prompts[rand() % k_prompts.size()];
+                    client.prompt = k_system + client.input + "\nAssistant:";
+                    client.response = "";
+                    std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+
+                    std::vector<llama_token> prompt_tokens;
+                    prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+
+                    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                        batch_token.push_back(prompt_tokens[i]);
+                        batch_pos.push_back(i);
+                        batch_seq_id.push_back(client.seq_id);
+                        batch_clients.push_back(&client);
+                    }
+                    client.n_prompt  = prompt_tokens.size();
+                    client.n_decoded = prompt_tokens.size();
+                    client.i_batch   = batch_token.size() - 1;
+
+                    g_seq_id += 1;
                 }
-                client.n_prompt = prompt_tokens.size();
-                client.n_decoded = prompt_tokens.size();
-                client.i_batch = batch_token.size() - 1;
-
-                g_seq_id += 1;
-            } else {
-                batch_token.push_back(client.sampled);
-                batch_pos.push_back(client.n_decoded);
-                batch_seq_id.push_back(client.seq_id);
-                batch_clients.push_back(&client);
-                client.n_decoded += 1;
-                client.i_batch = batch_token.size() - 1;
             }
         }
 
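
The batching loop is now two-phase: first, every active client contributes exactly one token (its last sampled one) to the batch; then, only when no client is active, the whole KV cache is cleared and every idle slot is seeded with a fresh prompt. A self-contained toy sketch of this gather-then-refill schedule (mock types standing in for the llama.cpp context and batch, not the real API):

    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    struct MockClient {
        int seq_id = -1; // -1 means the slot is idle
        int n_left =  0; // tokens this sequence still wants to generate
    };

    int main() {
        std::vector<MockClient> clients(3);
        int g_seq_id = 0;

        for (int step = 0; step < 16; ++step) {
            std::vector<int> batch; // one entry per active sequence

            for (auto & c : clients) {
                if (c.seq_id == -1) continue;       // skip idle slots
                batch.push_back(c.seq_id);          // continue this sequence
                if (--c.n_left == 0) c.seq_id = -1; // sequence finished
            }

            if (batch.empty()) {
                // all sequences ended - the real code clears the KV cache here
                printf("step %2d: new wave\n", step);
                for (auto & c : clients) {
                    c.seq_id = g_seq_id++;
                    c.n_left = 2 + rand() % 4; // stand-in for a real prompt
                }
            } else {
                printf("step %2d: decoding %zu sequence(s)\n", step, batch.size());
            }
        }
        return 0;
    }

With this schedule a new wave only starts once every sequence in the previous wave has finished, which is what makes the blanket llama_kv_cache_rm_tokens(ctx, -1, -1) safe.
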
@@ -188,6 +205,10 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
 
+            if (client.t_start_gen == 0) {
+                client.t_start_gen = ggml_time_us();
+            }
+
             // remember which tokens were sampled - used for repetition penalties during sampling
             client.last_tokens.erase(client.last_tokens.begin());
             client.last_tokens.push_back(id);
@@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
             //    printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
             //            client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
-            if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
+            if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
+                    client.response.find("User:") != std::string::npos ||
+                    client.response.find('\n') != std::string::npos) {
+                // basic reverse prompt
                 const size_t pos = client.response.find("User:");
                 if (pos != std::string::npos) {
                     client.response = client.response.substr(0, pos);
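
A sequence now also ends when the response contains the reverse prompt "User:" (the model starting to write the next turn itself) or a newline. A small standalone illustration of the trimming that follows (hypothetical helper name, not from the patch):

    #include <cstdio>
    #include <string>

    // cut a finished response at the reverse prompt, as the hunk above does
    static std::string cut_at_reverse_prompt(std::string response) {
        const size_t pos = response.find("User:");
        if (pos != std::string::npos) {
            response = response.substr(0, pos); // drop the model's fake next turn
        }
        return response;
    }

    int main() {
        const std::string raw = "It is 25 degrees outside. User: thanks!";
        printf("'%s'\n", cut_at_reverse_prompt(raw).c_str()); // 'It is 25 degrees outside. '
        return 0;
    }
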
@@ -211,13 +235,18 @@ int main(int argc, char ** argv) {
 
                 n_tokens_total += client.n_decoded - client.n_prompt;
 
-                printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput:    %s\nResponse: %s\n\n",
+                printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput:    %s\nResponse: %s\n\n",
                         client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
-                        (double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
-                        client.input.c_str(), ::trim(client.response).c_str());
+                        (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
+                        (double) (client.n_decoded - client.n_prompt) / (t_main_end         - client.t_start_gen)    * 1e6,
+                        (double) (client.n_decoded                  ) / (t_main_end         - client.t_start_prompt) * 1e6,
+                        ::trim(client.input).c_str(),
+                        ::trim(client.response).c_str());
 
                 client.seq_id = -1;
             }
+
+            client.i_batch = -1;
         }
     }
 
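
The per-client report now breaks throughput into prompt processing (PP), text generation (TG), and the overall average (AVG), computed from the per-client timestamps instead of the shared n_tokens_total counter. Since the timestamps are in microseconds, each rate is tokens / elapsed * 1e6. A worked example with made-up numbers:

    #include <cstdint>
    #include <cstdio>

    // tokens per second from a token count and a microsecond interval
    static double tok_per_sec(int32_t n_tokens, int64_t t_us) {
        return t_us > 0 ? (double) n_tokens / t_us * 1e6 : 0.0;
    }

    int main() {
        // hypothetical client: a 100-token prompt evaluated in 0.5 s,
        // then 64 tokens generated over the next 2.0 s
        const int64_t t_start_prompt = 0;
        const int64_t t_start_gen    =  500000;
        const int64_t t_main_end     = 2500000;
        const int32_t n_prompt       = 100;
        const int32_t n_decoded      = 164; // prompt + generated tokens

        printf("PP  %6.2f t/s\n", tok_per_sec(n_prompt,             t_start_gen - t_start_prompt)); // 200.00
        printf("TG  %6.2f t/s\n", tok_per_sec(n_decoded - n_prompt, t_main_end  - t_start_gen));    //  32.00
        printf("AVG %6.2f t/s\n", tok_per_sec(n_decoded,            t_main_end  - t_start_prompt)); //  65.60
        return 0;
    }

Resetting client.i_batch to -1 at the end of the loop body marks the client as having no pending logits, so a finished client is not sampled again on the next iteration.
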