diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6ffaa8d9f..80714fa58 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -147,7 +147,7 @@ struct server_slot { int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; + std::string prompt; // when a task is submitted, we first tokenize the prompt and store it here std::vector prompt_tokens; @@ -822,13 +822,8 @@ struct server_context { continue; } - // skip the slot if it does not contains prompt - if (!slot.prompt.is_string()) { - continue; - } - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); + std::string slot_prompt = slot.prompt; // length of the current slot's prompt int slot_prompt_len = slot_prompt.size(); @@ -958,13 +953,16 @@ struct server_context { if (!task.infill) { const auto & prompt = data.find("prompt"); if (prompt == data.end()) { - send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); return false; - } else { - slot.prompt = *prompt; } - if (slot.prompt.is_array() && slot.prompt.size() == 0) { - send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST); + + if (prompt->is_string()) { + slot.prompt = prompt->get(); + } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) { + slot.prompt = prompt->at(0).get(); + } else { + send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1582,14 +1580,18 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: { - int id_slot = json_value(task.data, "id_slot", -1); - std::string prompt = json_value(task.data, "prompt", std::string()); + const int id_slot = json_value(task.data, "id_slot", -1); server_slot * slot; if (id_slot != -1) { slot = get_slot_by_id(id_slot); } else { + std::string prompt; + if (task.data.contains("prompt") && task.data.at("prompt").is_string()) { + json_value(task.data, "prompt", std::string()); + } + slot = get_available_slot(prompt); }