Commit 466b513

parallel : disable hot-plug to avoid cache fragmentation
1 parent 0161372 commit 466b513

2 files changed: 64 additions, 31 deletions

examples/parallel/parallel.cpp

Lines changed: 60 additions & 31 deletions
@@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
 }
 
 static std::string k_system = R"(
-Transcript of a dialog, where the User interacts with an Assistant.
+Transcript of a never ending dialog, where the User interacts with an Assistant.
 The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
 
 User: Hello, what is the temperature outside?
@@ -59,6 +59,9 @@ struct client {
 
     llama_token sampled;
 
+    int64_t t_start_prompt;
+    int64_t t_start_gen;
+
     int32_t n_prompt = 0;
     int32_t n_decoded = 0;
     int32_t i_batch = -1;
@@ -133,33 +136,47 @@ int main(int argc, char ** argv) {
 
         for (auto & client : clients) {
             if (client.seq_id == -1) {
-                client.seq_id = g_seq_id;
-                client.input = k_prompts[rand() % k_prompts.size()];
-                client.prompt = k_system + client.input + "\nAssistant:";
-                client.response = "";
-                std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
-
-                std::vector<llama_token> prompt_tokens;
-                prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
-
-                for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-                    batch_token.push_back(prompt_tokens[i]);
-                    batch_pos.push_back(i);
-                    batch_seq_id.push_back(client.seq_id);
-                    batch_clients.push_back(&client);
+                continue;
+            }
+
+            batch_token.push_back(client.sampled);
+            batch_pos.push_back(client.n_decoded);
+            batch_seq_id.push_back(client.seq_id);
+            batch_clients.push_back(&client);
+            client.n_decoded += 1;
+            client.i_batch = batch_token.size() - 1;
+        }
+
+        if (batch_token.empty()) {
+            // all sequences have ended - clear the entire KV cache
+            llama_kv_cache_rm_tokens(ctx, -1, -1);
+
+            for (auto & client : clients) {
+                if (client.seq_id == -1) {
+                    client.seq_id = g_seq_id;
+                    client.t_start_prompt = ggml_time_us();
+                    client.t_start_gen = 0;
+
+                    client.input = k_prompts[rand() % k_prompts.size()];
+                    client.prompt = k_system + client.input + "\nAssistant:";
+                    client.response = "";
+                    std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+
+                    std::vector<llama_token> prompt_tokens;
+                    prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+
+                    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                        batch_token.push_back(prompt_tokens[i]);
+                        batch_pos.push_back(i);
+                        batch_seq_id.push_back(client.seq_id);
+                        batch_clients.push_back(&client);
+                    }
+                    client.n_prompt = prompt_tokens.size();
+                    client.n_decoded = prompt_tokens.size();
+                    client.i_batch = batch_token.size() - 1;
+
+                    g_seq_id += 1;
                 }
-                client.n_prompt = prompt_tokens.size();
-                client.n_decoded = prompt_tokens.size();
-                client.i_batch = batch_token.size() - 1;
-
-                g_seq_id += 1;
-            } else {
-                batch_token.push_back(client.sampled);
-                batch_pos.push_back(client.n_decoded);
-                batch_seq_id.push_back(client.seq_id);
-                batch_clients.push_back(&client);
-                client.n_decoded += 1;
-                client.i_batch = batch_token.size() - 1;
             }
         }
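In outline, the reworked loop above splits sequence start-up from steady-state decoding: running clients only contribute their last sampled token to the batch, and new prompts are seeded only once every sequence has finished and the entire KV cache has been cleared, so each fresh wave starts in an empty, unfragmented cache. Below is a minimal standalone sketch of that scheduling policy; ToyClient and kv_cells are invented stand-ins, not the llama.cpp API.

// Minimal sketch of the "no hot-plug" scheduling policy from the hunk above,
// using toy types instead of the llama.cpp API.
#include <cstdio>
#include <vector>

struct ToyClient {
    int seq_id = -1; // -1 means "no active sequence"
    int n_left = 0;  // tokens still to generate for the current sequence
};

int main() {
    std::vector<ToyClient> clients(4);
    std::vector<int> kv_cells; // stand-in for occupied KV cache cells
    int g_seq_id = 0;

    for (int step = 0; step < 12; ++step) {
        int n_batch = 0;

        // steady state: only continue sequences that are already running
        for (auto & c : clients) {
            if (c.seq_id == -1) {
                continue; // no hot-plug: idle clients are not restarted here
            }
            kv_cells.push_back(c.seq_id); // the new token's KV cell
            n_batch += 1;
            if (--c.n_left == 0) {
                c.seq_id = -1; // sequence finished
            }
        }

        if (n_batch == 0) {
            // all sequences have ended - clear the entire (toy) KV cache
            // and start a fresh wave of prompts in one go
            kv_cells.clear();
            for (auto & c : clients) {
                c.seq_id = g_seq_id++;
                c.n_left = 2 + (c.seq_id % 3);
                kv_cells.push_back(c.seq_id); // prompt lands in an empty cache
            }
        }

        printf("step %2d: %3zu KV cells in use\n", step, kv_cells.size());
    }

    return 0;
}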

@@ -188,6 +205,10 @@ int main(int argc, char ** argv) {
 
                 const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
 
+                if (client.t_start_gen == 0) {
+                    client.t_start_gen = ggml_time_us();
+                }
+
                 // remember which tokens were sampled - used for repetition penalties during sampling
                 client.last_tokens.erase(client.last_tokens.begin());
                 client.last_tokens.push_back(id);
@@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
                 //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
                 //    client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
-                if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
+                if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
+                    client.response.find("User:") != std::string::npos ||
+                    client.response.find('\n') != std::string::npos) {
+                    // basic reverse prompt
                     const size_t pos = client.response.find("User:");
                     if (pos != std::string::npos) {
                         client.response = client.response.substr(0, pos);
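The widened stop condition now ends a sequence on end-of-sequence, generation length, the reverse prompt "User:", or a newline, and the existing trimming then drops anything from "User:" onward. A tiny standalone illustration of that trimming step follows; the response string is invented.

// Standalone illustration of the "basic reverse prompt" trimming above:
// everything from the first "User:" onward is cut out of the stored response.
#include <cstdio>
#include <string>

int main() {
    std::string response = "It is about 20 degrees outside. User: thanks"; // invented example

    const size_t pos = response.find("User:");
    if (pos != std::string::npos) {
        response = response.substr(0, pos);
    }

    printf("trimmed: '%s'\n", response.c_str());
    return 0;
}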
@@ -211,13 +235,18 @@ int main(int argc, char ** argv) {
 
                     n_tokens_total += client.n_decoded - client.n_prompt;
 
-                    printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
+                    printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput: %s\nResponse: %s\n\n",
                             client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
-                            (double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
-                            client.input.c_str(), ::trim(client.response).c_str());
+                            (double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
+                            (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
+                            (double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
+                            ::trim(client.input).c_str(),
+                            ::trim(client.response).c_str());
 
                     client.seq_id = -1;
                 }
+
+                client.i_batch = -1;
             }
         }
 
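The new report derives three per-client rates from the timestamps added in this commit: PP (prompt processing) over the window from t_start_prompt to t_start_gen, TG (generation) from t_start_gen to the end of the sequence, and AVG over the whole window. Below is a self-contained worked example of that arithmetic; the numbers are invented, with timestamps in microseconds as returned by ggml_time_us.

// Worked example of the PP / TG / AVG speed calculation used in the new printf
// (all values here are invented).
#include <cstdio>
#include <cstdint>

int main() {
    const int32_t n_prompt  = 128;           // tokens in the prompt
    const int32_t n_decoded = 128 + 64;      // prompt tokens + generated tokens

    const int64_t t_start_prompt = 0;        // when the prompt was submitted
    const int64_t t_start_gen    = 800000;   // when the first token was sampled (0.8 s later)
    const int64_t t_main_end     = 2400000;  // when the sequence finished (2.4 s total)

    const double pp  = (double) n_prompt               / (t_start_gen - t_start_prompt) * 1e6;
    const double tg  = (double) (n_decoded - n_prompt) / (t_main_end  - t_start_gen)    * 1e6;
    const double avg = (double) n_decoded              / (t_main_end  - t_start_prompt) * 1e6;

    // prints: PP 160.00 t/s, TG 40.00 t/s, AVG 80.00 t/s
    printf("PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s\n", pp, tg, avg);
    return 0;
}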
llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -2606,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
 
+    //printf("n_kv = %d\n", n_kv);
+
     const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
 
     auto & buf_compute = lctx.buf_compute;
@@ -4052,6 +4054,8 @@ static bool llama_eval_internal(
         batch.seq_id = seq_id.data();
     }
 
+    kv_self.head = 0;
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return false;
     }
