Commit d87b48f

server: move all mutexes away from server.cpp

1 parent: 58fe9cf

2 files changed: +90, -85

examples/server/server.cpp (+25, -72)
@@ -25,7 +25,6 @@
 
 #include <cstddef>
 #include <thread>
-#include <mutex>
 #include <chrono>
 #include <condition_variable>
 #include <atomic>
@@ -328,10 +327,8 @@ struct llama_server_context
     // slots / clients
     std::vector<llama_client_slot> slots;
 
-    llama_server_queue<task_server> queue_tasks;
-    llama_server_response_event queue_results;
-    std::vector<task_multi> queue_multitasks;
-    std::mutex mutex_multitasks;
+    llama_server_queue queue_tasks;
+    llama_server_response queue_results;
 
     ~llama_server_context()
     {
@@ -961,30 +958,6 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void add_multitask(int id, std::vector<int>& sub_ids)
-    {
-        std::lock_guard<std::mutex> lock(mutex_multitasks);
-        task_multi multi;
-        multi.id = id;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-        // TODO @ngxson : Do we need to notify the queue_tasks?
-    }
-
-    void update_multitask(int multitask_id, int subtask_id, task_result& result)
-    {
-        std::lock_guard<std::mutex> lock(mutex_multitasks);
-        for (auto& multitask : queue_multitasks)
-        {
-            if (multitask.id == multitask_id)
-            {
-                multitask.subtasks_remaining.erase(subtask_id);
-                multitask.results.push_back(result);
-                // TODO @ngxson : Do we need to notify the queue_tasks?
-            }
-        }
-    }
-
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1120,7 +1093,7 @@ struct llama_server_context
         // parent multitask, if any, needs to be updated
         if (slot.multitask_id != -1)
         {
-            update_multitask(slot.multitask_id, slot.task_id, res);
+            queue_tasks.update_multitask(slot.multitask_id, slot.task_id, res);
         }
     }
 
@@ -1157,7 +1130,6 @@ struct llama_server_context
 
     int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::unique_lock<std::mutex> lock(mutex_multitasks);
         task_server task;
         task.target_id = 0;
         task.data = std::move(data);
@@ -1169,7 +1141,6 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            lock.unlock(); // entering new func scope
             return split_multiprompt_task(task);
         }
 
@@ -1270,11 +1241,11 @@ struct llama_server_context
         }
 
         // queue up the multitask so we can track its subtask progression
-        add_multitask(multitask_id, subtask_ids);
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
         return multitask_id;
     }
 
-    void process_single_task(task_server task)
+    void process_single_task(task_server& task)
     {
         switch (task.type)
         {
@@ -1283,7 +1254,7 @@ struct llama_server_context
                 if (slot == nullptr)
                 {
                     // if no slot is available, we defer this task for processing later
-                    LOG_TEE("no slot\n");
+                    LOG_VERBOSE("no slot is available", {});
                     queue_tasks.defer(task);
                     break;
                 }
@@ -1333,42 +1304,23 @@ struct llama_server_context
         }
     }
 
-    void process_multitask()
+    void on_finish_multitask(task_multi& multitask)
     {
-        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
-        std::vector<task_result> agg_results;
-        auto queue_iterator = queue_multitasks.begin();
-        while (queue_iterator != queue_multitasks.end())
-        {
-            if (queue_iterator->subtasks_remaining.empty())
-            {
-                // all subtasks done == multitask is done
-                task_result aggregate_result;
-                aggregate_result.id = queue_iterator->id;
-                aggregate_result.stop = true;
-                aggregate_result.error = false;
-
-                // collect json results into one json result
-                std::vector<json> result_jsons;
-                for (auto& subres : queue_iterator->results)
-                {
-                    result_jsons.push_back(subres.result_json);
-                    aggregate_result.error = aggregate_result.error && subres.error;
-                }
-                aggregate_result.result_json = json{ "results", result_jsons };
-                agg_results.push_back(aggregate_result);
-                queue_iterator = queue_multitasks.erase(queue_iterator);
-            }
-            else
-            {
-                ++queue_iterator;
-            }
-        }
+        // all subtasks done == multitask is done
+        task_result result;
+        result.id = multitask.id;
+        result.stop = true;
+        result.error = false;
 
-        // copy aggregate results of complete multi-tasks to the results queue
-        for (auto& res : agg_results) {
-            queue_results.send(res);
+        // collect json results into one json result
+        std::vector<json> result_jsons;
+        for (auto& subres : multitask.results)
+        {
+            result_jsons.push_back(subres.result_json);
+            result.error = result.error && subres.error;
         }
+        result.result_json = json{ "results", result_jsons };
+        queue_results.send(result);
     }
 
     bool update_slots() {
@@ -1704,7 +1656,6 @@ struct llama_server_context
     }
 
     void run_on_all_tasks_finished() {
-        process_multitask();
         update_slots();
     }
 };
@@ -2861,16 +2812,18 @@ int main(int argc, char **argv)
 
     llama.queue_tasks.on_new_task(std::bind(
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
+    llama.queue_tasks.on_finish_multitask(std::bind(
+        &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
     llama.queue_tasks.on_all_tasks_finished(std::bind(
         &llama_server_context::run_on_all_tasks_finished, &llama));
-    llama.queue_tasks.start_loop();
     llama.queue_results.on_multitask_update(std::bind(
-        &llama_server_context::update_multitask,
-        &llama,
+        &llama_server_queue::update_multitask,
+        &llama.queue_tasks,
         std::placeholders::_1,
         std::placeholders::_2,
         std::placeholders::_3
     ));
+    llama.queue_tasks.start_loop();
 
     t.join();
 
examples/server/utils.hpp (+65, -13)
@@ -187,18 +187,21 @@ inline std::string format_chatml(std::vector<json> messages)
 // work queue utils
 //
 
-template<typename T>
 struct llama_server_queue {
     int id = 0;
     std::mutex mutex_tasks;
-    std::vector<T> queue_tasks;
-    std::vector<T> queue_tasks_deferred;
+    // queues
+    std::vector<task_server> queue_tasks;
+    std::vector<task_server> queue_tasks_deferred;
+    std::vector<task_multi> queue_multitasks;
     std::condition_variable condition_tasks;
-    std::function<void(T)> callback_new_task;
+    // callback functions
+    std::function<void(task_server&)> callback_new_task;
+    std::function<void(task_multi&)> callback_finish_multitask;
     std::function<void(void)> callback_all_task_finished;
 
     // Add a new task to the end of the queue
-    int post(T task) {
+    int post(task_server task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         task.id = id++;
         queue_tasks.push_back(std::move(task));
@@ -207,7 +210,7 @@ struct llama_server_queue {
     }
 
     // Add a new task, but defer until the next loop
-    void defer(T task) {
+    void defer(task_server task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
     }
@@ -219,10 +222,15 @@ struct llama_server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(T)> callback) {
+    void on_new_task(std::function<void(task_server&)> callback) {
        callback_new_task = callback;
    }
 
+    // Register function to process a multitask
+    void on_finish_multitask(std::function<void(task_multi&)> callback) {
+        callback_finish_multitask = callback;
+    }
+
     // Register the function to be called when the batch of tasks is finished
     void on_all_tasks_finished(std::function<void(void)> callback) {
         callback_all_task_finished = callback;
@@ -257,6 +265,24 @@ struct llama_server_queue {
                 lock.unlock();
             }
             LOG_VERBOSE("callback_all_task_finished", {});
+            // process and update all the multitasks
+            auto queue_iterator = queue_multitasks.begin();
+            while (queue_iterator != queue_multitasks.end())
+            {
+                if (queue_iterator->subtasks_remaining.empty())
+                {
+                    // all subtasks done == multitask is done
+                    task_multi current_multitask = *queue_iterator;
+                    callback_finish_multitask(current_multitask);
+                    // remove this multitask
+                    queue_iterator = queue_multitasks.erase(queue_iterator);
+                }
+                else
+                {
+                    ++queue_iterator;
+                }
+            }
+            // all tasks in the current loop is finished
             callback_all_task_finished();
         }
         LOG_VERBOSE("wait for new task", {});
@@ -271,26 +297,53 @@ struct llama_server_queue {
             }
         }
     }
+
+    //
+    // functions to manage multitasks
+    //
+
+    // add a multitask by specifying the id of all subtask (subtask is a task_server)
+    void add_multitask(int multitask_id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = multitask_id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    // updatethe remaining subtasks, while appending results to multitask
+    void update_multitask(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
 };
 
-struct llama_server_response_event {
+struct llama_server_response {
     typedef std::function<void(int, int, task_result&)> callback_multitask_t;
     callback_multitask_t callback_update_multitask;
     // for keeping track of all tasks waiting for the result
-    std::mutex mutex_task_ids;
     std::set<int> waiting_task_ids;
     // the main result queue
     std::vector<task_result> queue_results;
     std::mutex mutex_results;
     std::condition_variable condition_results;
 
     void add_waiting_task_id(int task_id) {
-        std::unique_lock<std::mutex> lock(mutex_task_ids);
+        std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.insert(task_id);
     }
 
     void remove_waiting_task_id(int task_id) {
-        std::unique_lock<std::mutex> lock(mutex_task_ids);
+        std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.erase(task_id);
     }
 
@@ -327,7 +380,6 @@ struct llama_server_response_event {
     // Send a new result to a waiting task_id
     void send(task_result result) {
         std::unique_lock<std::mutex> lock(mutex_results);
-        std::unique_lock<std::mutex> lock1(mutex_task_ids);
         LOG_VERBOSE("send new result", {});
         for (auto& task_id : waiting_task_ids) {
             // LOG_TEE("waiting task id %i \n", task_id);
@@ -449,4 +501,4 @@ static std::string gen_chatcmplid()
     std::stringstream chatcmplid;
     chatcmplid << "chatcmpl-" << random_string();
     return chatcmplid.str();
-}
+}
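For orientation, below is a minimal, self-contained analogue of the pattern this commit consolidates into llama_server_queue: the queue owns the single mutex and the multitask bookkeeping, and the caller (like main() above) only registers callbacks via std::bind. The types and names here are simplified placeholders for illustration, not the actual llama.cpp API.

// Minimal standalone sketch (not the llama.cpp API): a queue that owns its own
// mutex and multitask bookkeeping, with callbacks registered by the caller.
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

struct task      { int id; };
struct multitask { int id; std::vector<int> pending; };

struct event_queue {
    std::mutex mtx;                         // single mutex guarding both vectors
    std::vector<task>      tasks;
    std::vector<multitask> multitasks;
    std::function<void(task&)>      on_task;
    std::function<void(multitask&)> on_multitask_done;

    void post(task t)               { std::lock_guard<std::mutex> l(mtx); tasks.push_back(t); }
    void add_multitask(multitask m) { std::lock_guard<std::mutex> l(mtx); multitasks.push_back(m); }

    // one pass of the loop: drain pending tasks, then fire finished multitasks
    void run_once() {
        std::lock_guard<std::mutex> l(mtx);
        for (auto& t : tasks) { on_task(t); }
        tasks.clear();
        for (auto it = multitasks.begin(); it != multitasks.end(); ) {
            if (it->pending.empty()) { on_multitask_done(*it); it = multitasks.erase(it); }
            else                     { ++it; }
        }
    }
};

struct server {
    void handle(task& t)        { std::cout << "task "      << t.id << "\n"; }
    void finished(multitask& m) { std::cout << "multitask " << m.id << "\n"; }
};

int main() {
    server s;
    event_queue q;
    // same wiring style as in main() above: bind member functions as callbacks
    q.on_task           = std::bind(&server::handle,   &s, std::placeholders::_1);
    q.on_multitask_done = std::bind(&server::finished, &s, std::placeholders::_1);
    q.post({1});
    q.add_multitask({7, {}});
    q.run_once();               // prints "task 1" then "multitask 7"
}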
