server : fix infill when prompt is empty (#4833)

This commit is contained in:
Georgi Gerganov 2024-01-11 23:23:49 +02:00 committed by GitHub
parent 7edefbd79c
commit 1d118386fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1406,7 +1406,7 @@ struct llama_server_context
task.multitask_id = multitask_id; task.multitask_id = multitask_id;
// when a completion task's prompt array is not a singleton, we split it into multiple requests // when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.at("prompt").size() > 1) if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{ {
lock.unlock(); // entering new func scope lock.unlock(); // entering new func scope
return split_multiprompt_task(task); return split_multiprompt_task(task);
@ -1577,9 +1577,9 @@ struct llama_server_context
slot->reset(); slot->reset();
slot->infill = task.infill_mode; slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode; slot->embedding = task.embedding_mode;
slot->task_id = task.id; slot->task_id = task.id;
slot->multitask_id = task.multitask_id; slot->multitask_id = task.multitask_id;
if (!launch_slot_with_data(slot, task.data)) if (!launch_slot_with_data(slot, task.data))
@ -1731,7 +1731,8 @@ struct llama_server_context
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty(); const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
// empty prompt passed -> release the slot and send empty response // empty prompt passed -> release the slot and send empty response
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt) // note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
{ {
slot.release(); slot.release();
slot.print_timings(); slot.print_timings();
@ -2609,8 +2610,8 @@ static json format_final_response_oaicompat(const json &request, const task_resu
{"object", streaming ? "chat.completion.chunk" : "chat.completion"}, {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage", {"usage",
json{{"completion_tokens", num_tokens_predicted}, json{{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens}, {"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
{"id", gen_chatcmplid()}}; {"id", gen_chatcmplid()}};
if (server_verbose) { if (server_verbose) {