server : improve "prompt" handling (#7847)

This commit is contained in:
Georgi Gerganov 2024-06-10 14:59:55 +03:00 committed by GitHub
parent 1f0dabda8d
commit d9da0e4986
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -147,7 +147,7 @@ struct server_slot {
int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens_processed = 0;
json prompt; std::string prompt;
// when a task is submitted, we first tokenize the prompt and store it here // when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens; std::vector<llama_token> prompt_tokens;
@ -822,13 +822,8 @@ struct server_context {
continue; continue;
} }
// skip the slot if it does not contain a prompt
if (!slot.prompt.is_string()) {
continue;
}
// current slot's prompt // current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>(); std::string slot_prompt = slot.prompt;
// length of the current slot's prompt // length of the current slot's prompt
int slot_prompt_len = slot_prompt.size(); int slot_prompt_len = slot_prompt.size();
@ -958,13 +953,16 @@ struct server_context {
if (!task.infill) { if (!task.infill) {
const auto & prompt = data.find("prompt"); const auto & prompt = data.find("prompt");
if (prompt == data.end()) { if (prompt == data.end()) {
send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST); send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} else {
slot.prompt = *prompt;
} }
if (slot.prompt.is_array() && slot.prompt.size() == 0) {
send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST); if (prompt->is_string()) {
slot.prompt = prompt->get<std::string>();
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
slot.prompt = prompt->at(0).get<std::string>();
} else {
send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
} }
@ -1582,14 +1580,18 @@ struct server_context {
switch (task.type) { switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_COMPLETION:
{ {
int id_slot = json_value(task.data, "id_slot", -1); const int id_slot = json_value(task.data, "id_slot", -1);
std::string prompt = json_value(task.data, "prompt", std::string());
server_slot * slot; server_slot * slot;
if (id_slot != -1) { if (id_slot != -1) {
slot = get_slot_by_id(id_slot); slot = get_slot_by_id(id_slot);
} else { } else {
// Prompt text handed to get_available_slot() right after this block.
// It stays empty when "prompt" is absent or is not a plain JSON string.
std::string prompt;
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
    // Store the extracted value: the result of json_value() was previously
    // discarded, so the slot lookup always received an empty prompt.
    prompt = json_value(task.data, "prompt", std::string());
}
slot = get_available_slot(prompt); slot = get_available_slot(prompt);
} }