server : add single-client multi-prompt support (#4232)

* add multiprompt support

* cleanup

* more cleanup

* remove atomicity of id_gen, and change lock_guard to unique_lock on completion requests

* remove all references to mutex_multitasks

* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* change to set

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Authored by Ziad Ben Hadj-Alouane on 2023-11-30 17:25:04 -05:00, committed by GitHub
parent d2809a3ba2
commit f43f09366d

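For orientation before the diff: with this change, a /completion request whose "prompt" field is a JSON array is split into one subtask per prompt and answered with a single aggregated result. The fragment below is not part of the commit; it is a sketch that mirrors the non-streaming handler path further down (`llama` is the llama_server_context instance and `res` the httplib::Response from main(); the payload is illustrative).

```cpp
// Fragment, not a standalone program: it mirrors the non-streaming path of the
// /completion handler changed below. `llama` is the llama_server_context and
// `res` the httplib::Response, as in main(); the payload is illustrative.
json data = json::parse(R"({ "prompt": ["Hello", "Bonjour"], "n_predict": 16 })");

// last argument -1 = "no parent multitask"; because "prompt" is an array, the
// request is split into one subtask per prompt and the returned id is the multitask id
const int task_id = llama.request_completion(data, false, false, -1);

// next_result() returns once every subtask has finished and been aggregated
task_result result = llama.next_result(task_id);
res.set_content(result.result_json.dump(), "application/json");
```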

@@ -155,15 +155,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +414,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -529,7 +540,8 @@ struct llama_server_context
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi> queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
 
     ~llama_server_context()
@@ -1112,17 +1124,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
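
The two helpers above carry all of the multitask bookkeeping: add_multi_task() seeds subtasks_remaining with every subtask id, and update_multi_task() erases an id and records its result as each subtask finishes (this is also what the "change to set" item in the commit message refers to: a std::set keeps the ids unique and makes erase/lookup logarithmic). A toy, self-contained sketch of that pattern, with locking and task_result omitted:

```cpp
// Toy illustration of the subtasks_remaining bookkeeping above
// (simplified: no mutex, plain ints instead of task_result).
#include <cstdio>
#include <set>
#include <vector>

int main() {
    std::vector<int> subtask_ids = {4, 5, 6};   // ids handed out by id_gen
    std::set<int> subtasks_remaining(subtask_ids.begin(), subtask_ids.end());

    // results may arrive in any order; each one erases its id from the set
    for (int finished : {6, 4, 5}) {
        subtasks_remaining.erase(finished);
        std::printf("subtask %d finished, %zu remaining\n",
                    finished, subtasks_remaining.size());
    }

    // an empty set is how process_tasks() detects a finished multitask
    return subtasks_remaining.empty() ? 0 : 1;
}
```
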
@@ -1167,6 +1202,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
@@ -1206,6 +1242,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
@@ -1251,6 +1288,12 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // parent multitask, if any, needs to be updated
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+
         queue_results.push_back(res);
     }
@@ -1259,6 +1302,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
@@ -1285,9 +1329,9 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
@@ -1295,6 +1339,16 @@ struct llama_server_context
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            lock.unlock(); // entering new func scope
+            return split_multiprompt_task(task);
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
         return task.id;
     }
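
The lock_guard-to-unique_lock change above exists because of the new early-return path: split_multiprompt_task() re-enters request_completion() once per sub-prompt, and each of those calls (as well as add_multi_task()) locks mutex_tasks again, so the outer lock must be released first on this non-recursive mutex. A stripped-down, self-contained sketch of that pattern; queue_single_task() below is a made-up stand-in, only the mutex name matches server.cpp:

```cpp
// Sketch of why the outer lock is a std::unique_lock: it must be released
// before re-entering a path that locks the same non-recursive mutex,
// otherwise the thread would deadlock on itself.
#include <cstdio>
#include <mutex>

std::mutex mutex_tasks;

void queue_single_task(int i) {
    std::lock_guard<std::mutex> lock(mutex_tasks);  // inner call takes the lock itself
    std::printf("queued subtask %d\n", i);
}

int main() {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    // ... inspect the request while holding the lock ...
    lock.unlock();                                  // release before re-entering
    for (int i = 0; i < 3; i++) {
        queue_single_task(i);                       // would deadlock if still held
    }
    return 0;
}
```
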
@@ -1313,8 +1367,17 @@ struct llama_server_context
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1467,27 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task(task_server& multiprompt_task)
+    {
+        auto prompt_count = multiprompt_task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = multiprompt_task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
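
split_multiprompt_task() copies the parent request for every prompt and swaps out only the "prompt" field, which is how n_predict, sampling parameters, and infill/embedding mode are inherited by each subtask. A standalone sketch of that derivation, assuming nlohmann::json (which server.cpp already uses) and an illustrative payload:

```cpp
// Standalone sketch of how each subtask's payload is derived from the parent
// request: copy everything, replace only "prompt" with the i-th element.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json parent = { {"prompt", {"Hello", "Bonjour", "Hola"}}, {"n_predict", 16} };

    for (size_t i = 0; i < parent.at("prompt").size(); i++) {
        json subtask_data = parent;                      // inherit every other field
        subtask_data["prompt"] = subtask_data["prompt"][i];
        std::cout << subtask_data.dump() << "\n";        // one single-prompt request each
    }
    return 0;
}
```
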
@@ -1419,7 +1503,7 @@ struct llama_server_context
                 {
                     LOG_TEE("slot unavailable\n");
                     // send error result
-                    send_error(task.id, "slot unavailable");
+                    send_error(task, "slot unavailable");
                     return;
                 }
@@ -1433,11 +1517,12 @@ struct llama_server_context
                     slot->infill = task.infill_mode;
                     slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
+                    slot->multitask_id = task.multitask_id;
 
                     if (!launch_slot_with_data(slot, task.data))
                     {
                         // send error result
-                        send_error(task.id, "internal_error");
+                        send_error(task, "internal_error");
                         break;
                     }
                 } break;
@@ -1453,6 +1538,38 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result;
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons;
+                for (auto& subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
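
The cleanup loop above walks queue_multitasks with the usual erase-while-iterating idiom: std::vector::erase() returns the iterator to the next element, so the iterator is only advanced manually when the current entry is kept. A minimal self-contained version of that idiom:

```cpp
// Minimal version of the erase-while-iterating idiom used in process_tasks():
// erase() yields the next iterator, so ++ only happens when the element stays.
#include <iostream>
#include <vector>

int main() {
    std::vector<int> remaining = {1, 0, 2, 0, 3};   // 0 stands for "finished multitask"
    auto it = remaining.begin();
    while (it != remaining.end()) {
        if (*it == 0) {
            it = remaining.erase(it);               // remove and get the next iterator
        } else {
            ++it;
        }
    }
    for (int v : remaining) std::cout << v << ' ';  // prints: 1 2 3
    std::cout << '\n';
    return 0;
}
```
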
@@ -2596,7 +2713,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2685,7 +2802,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2871,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2975,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });