server : various fixes for the prompt field in /completion (#5300)

server : fix deadlock when prompt array contains strings and numbers

server : removed an unnecessary generation when generating multi-prompts

server : removed an unnecessary assert
This commit is contained in:
Niall Coates 2024-02-06 08:16:23 +00:00 committed by GitHub
parent 906cff55c2
commit 4ffc7a17d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1163,13 +1163,30 @@ struct llama_server_context
task.multitask_id = multitask_id; task.multitask_id = multitask_id;
// when a completion task's prompt array is not a singleton, we split it into multiple requests // when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1) // otherwise, it's a single-prompt task, we actually queue it
{ // if there's numbers in the prompt array it will be treated as an array of tokens
split_multiprompt_task(task_id, task); if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
bool numbers = false;
for (const auto& e : task.data.at("prompt")) {
if (e.is_number()) {
numbers = true;
break;
}
} }
// otherwise, it's a single-prompt task, we actually queue it // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
// it will completely stall the server. I don't know where the bug for this is.
//
// if there are numbers, it needs to be treated like a single prompt,
// queue_tasks handles a mix of strings and numbers just fine.
if (numbers) {
queue_tasks.post(task); queue_tasks.post(task);
} else {
split_multiprompt_task(task_id, task);
}
} else {
queue_tasks.post(task);
}
} }
// for multiple images processing // for multiple images processing
@ -1251,7 +1268,10 @@ struct llama_server_context
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{ {
int prompt_count = multiprompt_task.data.at("prompt").size(); int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1); if (prompt_count <= 1) {
send_error(multiprompt_task, "error while handling multiple prompts");
return;
}
// generate all the ID for subtask // generate all the ID for subtask
std::vector<int> subtask_ids(prompt_count); std::vector<int> subtask_ids(prompt_count);