From 6b2437e32d146ae2012bf81730b35ff85a2e6347 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 20 Oct 2023 12:07:32 -0400 Subject: [PATCH] added thread safe pipeline --- examples/server/server.cpp | 1081 ++++++++++++++++++------------------ 1 file changed, 532 insertions(+), 549 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 851a3d314..3fa60e0b9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -23,6 +23,8 @@ #include #include +#include +#include #include #ifndef SERVER_VERBOSE @@ -137,6 +139,26 @@ static std::vector base64_decode(std::string const &encoded_string) // parallel // +enum task_type { + COMPLETION_TASK, + CANCEL_TASK +}; + +struct task_server { + int id; + int target_id; + task_type type; + json data; + bool infill_mode = false; +}; + +struct task_result { + int id; + bool stop; + bool error; + json result_json; +}; + enum slot_state { IDLE, @@ -309,6 +331,15 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector +static T json_value(const json &body, const std::string &key, const T &default_value) +{ + // Fallback null to default value + return body.contains(key) && !body.at(key).is_null() + ? body.value(key, default_value) + : default_value; +} + // TODO: this is not needed, should reuse llama_sampling_init from common/sampling.h static struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params &sparams, const std::string &grammar, int n_ctx) { @@ -342,6 +373,7 @@ static struct llama_sampling_context * llama_sampling_init_srv(const struct llam struct llama_client_slot { int id; + int task_id = -1; // generation props int32_t n_past = 0; int32_t n_decoded = 0; @@ -352,11 +384,9 @@ struct llama_client_slot json prompt; std::string generated_text; - int num_tokens_predicted = 0; llama_token sampled; std::vector cache_tokens; std::vector generated_token_probs; - int sent_tokens = 0; slot_state state = IDLE; slot_command command = NONE; bool truncated = false; @@ -366,6 +396,8 @@ struct llama_client_slot std::string stopping_word; int32_t multibyte_pending = 0; size_t sent_count = 0; + size_t sent_token_probs_index = 0; + bool infill = false; int64_t t_start_process_prompt; int64_t t_start_genereration; @@ -395,8 +427,9 @@ struct llama_client_slot multibyte_pending = 0; n_past = 0; sent_count = 0; + sent_token_probs_index = 0; infill = false; - clean_tokens(); + generated_token_probs.clear(); if (ctx_sampling != nullptr) { @@ -440,10 +473,6 @@ struct llama_client_slot return n_remaining > 0 || n_remaining == -1; // no budget || limitless } - bool has_new_token() const { - return num_tokens_predicted > sent_tokens; - } - bool available() const { return state == IDLE && command == NONE; } @@ -452,21 +481,13 @@ struct llama_client_slot return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING; } - completion_token_output next() { - completion_token_output tkn = generated_token_probs.at(sent_tokens); - sent_tokens++; - return tkn; - } - void add_token_string(const completion_token_output &token) { if (command == RELEASE) { - num_tokens_predicted = 0; return; } cache_tokens.push_back(token.tok); generated_token_probs.push_back(token); - num_tokens_predicted++; } void release() { @@ -477,11 +498,28 @@ struct llama_client_slot } } - void clean_tokens() - { - sent_tokens = 0; - generated_token_probs.clear(); - num_tokens_predicted = 0; + json get_formated_timings() { + return json + { + {"prompt_n", num_prompt_tokens_processed}, 
+ {"prompt_ms", t_prompt_processing}, + {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, + {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, + + {"predicted_n", n_decoded}, + {"predicted_ms", t_token_generation}, + {"predicted_per_token_ms", t_token_generation / n_decoded}, + {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, + }; + } + + void print_timings() { + LOG_TEE("\n"); + LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed); + LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded); + LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation); } }; @@ -513,6 +551,13 @@ struct llama_server_context int max_ctx_per_slot = -1; bool clean_kv_cache = true; + std::atomic id_gen; + + std::vector queue_tasks; + std::vector queue_results; + std::mutex mutex_tasks; + std::mutex mutex_results; + ~llama_server_context() { if (ctx) @@ -566,6 +611,7 @@ struct llama_server_context } void initialize() { + id_gen.store(0); // reset ids to 0 // create slots all_slots_are_idle = true; if(max_ctx_per_slot == -1) { @@ -647,14 +693,179 @@ struct llama_server_context return nullptr; } - bool launch_slot(llama_client_slot* &slot) { + bool launch_slot_with_data(llama_client_slot* &slot, json data) { + slot_params default_params; + llama_sampling_params default_sparams; + slot->params.stream = json_value(data, "stream", false); + slot->params.cache_prompt = json_value(data, "cache_prompt", false); + slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot->sparams.repeat_last_n = json_value(data, "repeat_last_n", default_sparams.repeat_last_n); + slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot->sparams.repeat_penalty = json_value(data, "repeat_penalty", default_sparams.repeat_penalty); + slot->sparams.presence_penalty = json_value(data, "presence_penalty", default_sparams.presence_penalty); + slot->sparams.frequency_penalty = json_value(data, "frequency_penalty", default_sparams.frequency_penalty); + slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); + slot->params.seed = json_value(data, "seed", default_params.seed); + slot->params.grammar = json_value(data, "grammar", default_params.grammar); + slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + + // infill + if (data.count("input_prefix") != 0) + { + slot->params.input_prefix = data["input_prefix"]; + } + else + { 
+ slot->params.input_prefix = ""; + } + if (data.count("input_suffix") != 0) + { + slot->params.input_suffix = data["input_suffix"]; + } + + // common params + else + { + slot->params.input_suffix = ""; + } + if (data.count("prompt") != 0) + { + slot->prompt = data["prompt"]; + } + else + { + slot->prompt = ""; + } + slot->sparams.logit_bias.clear(); + if (json_value(data, "ignore_eos", false)) + { + slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY; + } + const auto &logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) + { + const int n_vocab = llama_n_vocab(model); + for (const auto &el : *logit_bias) + { + if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) + { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) + { + if (el[1].is_number()) + { + slot->sparams.logit_bias[tok] = el[1].get(); + } + else if (el[1].is_boolean() && !el[1].get()) + { + slot->sparams.logit_bias[tok] = -INFINITY; + } + } + } + } + } + + slot->params.antiprompt.clear(); + const auto &stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) + { + for (const auto &word : *stop) + { + if (!word.empty()) + { + slot->params.antiprompt.push_back(word); + } + } + } + + if(multimodal) + { + const auto &images_data = data.find("image_data"); + if (images_data != data.end() && images_data->is_array()) + { + for (const auto &img : *images_data) + { + slot_image img_sl; + std::string data_b64 = img["data"].get(); + img_sl.id = img.count("id") != 0 ? img["id"].get() : slot->images.size(); + int width, height, channels; + std::vector image_buffer = base64_decode(data_b64); + data_b64.clear(); + auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3); + if (!data) { + LOG_TEE("slot %i - failed to load image id= %i\n", slot->id, img_sl.id); + return false; + } + LOG_TEE("slot %i - image id= %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height); + img_sl.img_data.nx = width; + img_sl.img_data.ny = height; + img_sl.img_data.size = width * height * 3; + img_sl.img_data.data = new uint8_t[width * height * 3](); + memcpy(img_sl.img_data.data, data, width * height * 3); + stbi_image_free(data); + img_sl.request_encode_image = true; + slot->images.push_back(img_sl); + } + // process prompt + // example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]} + if (slot->images.size() > 0 && !slot->prompt.is_array()) + { + std::string prompt = slot->prompt.get(); + size_t pos = 0, begin_prefix = 0; + std::string pattern = "[img-"; + while ((pos = prompt.find(pattern, pos)) != std::string::npos) { + size_t end_prefix = pos; + pos += pattern.length(); + size_t end_pos = prompt.find("]", pos); + if (end_pos != std::string::npos) + { + std::string image_id = prompt.substr(pos, end_pos - pos); + try + { + int img_id = std::stoi(image_id); + bool found = false; + for (slot_image &img : slot->images) + { + if (img.id == img_id) { + found = true; + img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); + begin_prefix = end_pos + 1; + break; + } + } + if (!found) { + LOG_TEE("ERROR: Image with id %i not found.\n", img_id); + slot->images.clear(); + return false; + } + } catch (const std::invalid_argument& e) { + LOG_TEE("Invalid image number id in prompt\n"); + slot->images.clear(); + return false; + } + } + } + slot->prompt = ""; + slot->params.input_suffix = 
prompt.substr(begin_prefix); + slot->params.cache_prompt = false; // multimodal doesn't support cache prompt + } + } + } if (!slot->load_grammar()) { return false; } all_slots_are_idle = false; slot->command = LOAD_PROMPT; - LOG_TEE("slot %i is processing\n", slot->id); + LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id); return true; } @@ -811,6 +1022,9 @@ struct llama_server_context // add the token to slot queue and cache } slot.add_token_string(result); + if(slot.params.stream) { + send_partial_response(slot, result); + } if (slot.multibyte_pending > 0) { slot.multibyte_pending -= token_str.size(); @@ -863,7 +1077,7 @@ struct llama_server_context {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining}, - {"num_tokens_predicted", slot.num_tokens_predicted}, + {"num_tokens_predicted", slot.n_decoded}, {"stopped_eos", slot.stopped_eos}, {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit}, @@ -907,6 +1121,179 @@ struct llama_server_context return slot.images.size() > 0; } + void send_error(int id, std::string error) { + std::lock_guard lock(mutex_results); + task_result res; + res.id = id; + res.error = true; + res.result_json = { { "content", error } }; + queue_results.push_back(res); + } + + json get_model_props() { + return get_formated_generation(slots[0]); + } + + json get_formated_generation(llama_client_slot & slot) { + const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx)); + const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && + eos_bias->second < 0.0f && std::isinf(eos_bias->second); + return json{ + {"n_ctx", max_ctx_per_slot}, + {"model", params.model_alias}, + {"seed", slot.params.seed}, + {"temp", slot.sparams.temp}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.repeat_last_n}, + {"repeat_penalty", slot.sparams.repeat_penalty}, + {"presence_penalty", slot.sparams.presence_penalty}, + {"frequency_penalty", slot.sparams.frequency_penalty}, + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"n_predict", slot.params.n_predict}, + {"n_keep", params.n_keep}, + {"ignore_eos", ignore_eos}, + {"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"grammar", slot.params.grammar}, + }; + } + + void send_partial_response(llama_client_slot & slot, completion_token_output tkn) { + std::lock_guard lock(mutex_results); + task_result res; + res.id = slot.task_id; + res.error = false; + res.stop = false; + res.result_json = json + { + {"content", tkn.text_to_send }, + {"stop", false}, + {"slot_id", slot.id }, + {"multimodal", multimodal } + }; + if (slot.sparams.n_probs > 0) + { + std::vector probs_output = {}; + const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); + size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); + size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) + { + probs_output = std::vector(slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos); + } + 
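+                // advance the cursor over the token probabilities that have already been
+                // streamed, so the next partial response only includes newly generated tokens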
slot.sent_token_probs_index = probs_stop_pos; + res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); + } + queue_results.push_back(res); + } + + void send_final_response(llama_client_slot & slot) { + std::lock_guard lock(mutex_results); + task_result res; + res.id = slot.task_id; + res.error = false; + res.stop = true; + res.result_json = json + { + {"content", !slot.params.stream ? slot.generated_text : ""}, + {"slot_id", slot.id}, + {"stop", true}, + {"model", params.model_alias}, + {"tokens_predicted", slot.n_decoded}, + {"tokens_evaluated", slot.num_prompt_tokens}, + {"generation_settings", get_formated_generation(slot)}, + {"prompt", slot.prompt}, + {"truncated", slot.truncated}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + {"tokens_cached", slot.n_past}, + {"timings", slot.get_formated_timings()} + }; + + if (slot.sparams.n_probs > 0) + { + std::vector probs = {}; + if(!slot.params.stream && slot.stopped_word) { + const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); + probs = std::vector(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); + } else { + probs = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.begin() + slot.sent_token_probs_index); + } + res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); + } + queue_results.push_back(res); + } + + void send_embedding(llama_client_slot & slot) { + std::lock_guard lock(mutex_results); + task_result res; + res.id = slot.task_id; + res.error = false; + res.stop = true; + static const int n_embd = llama_n_embd(model); + if (!params.embedding) + { + LOG_WARNING("embedding disabled", { + {"params.embedding", params.embedding}, + }); + res.result_json = json + { + {"embedding", std::vector(n_embd, 0.0f)}, + }; + } else { + const float *data = llama_get_embeddings(ctx); + std::vector embedding(data, data + n_embd); + res.result_json = json + { + {"embedding", embedding }, + }; + } + queue_results.push_back(res); + } + + int request_completion(json data, bool infill) { + std::lock_guard lock(mutex_tasks); + task_server task; + task.id = id_gen.load(); + id_gen.fetch_add(1); // increment id generator + task.data = data; + task.infill_mode = infill; + task.type = COMPLETION_TASK; + queue_tasks.push_back(task); + return task.id; + } + + task_result next_result(int task_id) { + while(true) { + std::this_thread::sleep_for(std::chrono::microseconds(5)); + std::lock_guard lock(mutex_results); + if(queue_results.empty()) { + continue; + } + + for(int i = 0; i < queue_results.size(); i++) { + if(queue_results[i].id == task_id) { + task_result res = queue_results[i]; + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + return task_result{-1, false, false, {}}; + } + // for multiple images processing bool ingest_images(llama_client_slot &slot, int n_batch) { @@ -974,7 +1361,68 @@ struct llama_server_context return true; } + void request_cancel(int task_id) { + std::lock_guard lock(mutex_tasks); + task_server task; + task.id = id_gen.load(); + id_gen.fetch_add(1); // increment id generator + task.type = CANCEL_TASK; + task.target_id = task_id; + queue_tasks.push_back(task); + } + + void process_tasks() { + std::lock_guard lock(mutex_tasks); + while(!queue_tasks.empty()) { + task_server task = queue_tasks.front(); + queue_tasks.erase(queue_tasks.begin()); + 
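+            // dispatch the dequeued task: COMPLETION_TASK binds the request data to a free
+            // slot and launches generation, while CANCEL_TASK releases the slot whose
+            // task_id matches target_id (used when a streaming client disconnects)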
+            switch (task.type)
+            {
+                case COMPLETION_TASK: { // perform completion task
+                    llama_client_slot* slot = get_slot(json_value(task.data, "slot_id", -1));
+                    if (slot == nullptr) {
+                        LOG_TEE("slot unavailable\n");
+                        // send error result
+                        send_error(task.id, "slot unavailable");
+                        return;
+                    }
+
+                    if (task.data.contains("system_prompt")) {
+                        process_system_prompt_data(task.data["system_prompt"]);
+                    }
+
+                    slot->reset();
+
+                    slot->infill  = task.infill_mode;
+                    slot->task_id = task.id;
+
+                    if (!launch_slot_with_data(slot, task.data))
+                    {
+                        // send error result
+                        send_error(task.id, "internal_error");
+                        break;
+                    }
+                }
+                break;
+                case CANCEL_TASK: { // release slot linked with the task id
+                    for(auto & slot : slots) {
+                        if(slot.task_id == task.target_id) {
+                            slot.release();
+                            break;
+                        }
+                    }
+                }
+                break;
+
+                default:
+                    break;
+            }
+        }
+    }
+
     bool update_slots() {
+        // attend tasks
+        process_tasks();
+
         // update the system prompt wait until all slots are idle state
         if (need_update_system_prompt)
         {
@@ -983,8 +1431,6 @@
         llama_batch_clear(batch);
 
-        int kv_cache_free = (n_ctx - num_tokens_system);
-
         if (all_slots_are_idle)
         {
             if (system_prompt.empty() && clean_kv_cache)
@@ -1029,7 +1475,7 @@
         for (auto & slot : slots)
         {
             // release the slot
-            if (slot.state == PROCESSING && slot.command == RELEASE && !slot.has_new_token())
+            if (slot.state == PROCESSING && slot.command == RELEASE)
             {
                 slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
                 if(slot.state == SLEEPING) {
@@ -1043,8 +1489,6 @@
                 continue;
             }
 
-            kv_cache_free -= slot.num_prompt_tokens;
-
             if (
                 slot.state == IDLE ||
                 slot.state == SLEEPING ||
@@ -1238,6 +1682,7 @@
                 // prompt evaluated for embedding
                 if (params.embedding)
                 {
+                    send_embedding(slot);
                     slot.release();
                     slot.i_batch = -1;
                     return true;
@@ -1272,34 +1717,15 @@
            if (!process_token(result, slot))
            {
                slot.release();
+               send_final_response(slot);
+               slot.print_timings();
            }
-           kv_cache_free -= slot.num_tokens_predicted;
            slot.i_batch = -1;
        }
    }
-
-       if(kv_cache_free < 0 && params.n_parallel > 1) {
-           LOG_TEE("\nError: kv cache is full, increase context size.");
-           return false;
-       }
        return true;
    }
-
-    std::vector<float> get_embedding()
-    {
-        static const int n_embd = llama_n_embd(model);
-        if (!params.embedding)
-        {
-            LOG_WARNING("embedding disabled", {
-                {"params.embedding", params.embedding},
-            });
-            return std::vector<float>(n_embd, 0.0f);
-        }
-        const float *data = llama_get_embeddings(ctx);
-        std::vector<float> embedding(data, data + n_embd);
-        return embedding;
-    }
 };
 
 static void server_print_usage(const char *argv0, const gpt_params &params,
@@ -1675,103 +2101,6 @@
     }
 }
 
-static void slot_print_timings(struct llama_client_slot * slot) {
-    LOG_TEE("\n");
-    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-        __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
-    LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
-        __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
-    LOG_TEE("%s: total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
-}
-
-static json
format_generation_settings(llama_server_context &llama, llama_client_slot* slot) -{ - const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx)); - const bool ignore_eos = eos_bias != slot->sparams.logit_bias.end() && - eos_bias->second < 0.0f && std::isinf(eos_bias->second); - - return json{ - {"n_ctx", llama.n_ctx}, - {"model", llama.params.model_alias}, - {"seed", slot->params.seed}, - {"temp", slot->sparams.temp}, - {"top_k", slot->sparams.top_k}, - {"top_p", slot->sparams.top_p}, - {"tfs_z", slot->sparams.tfs_z}, - {"typical_p", slot->sparams.typical_p}, - {"repeat_last_n", slot->sparams.repeat_last_n}, - {"repeat_penalty", slot->sparams.repeat_penalty}, - {"presence_penalty", slot->sparams.presence_penalty}, - {"frequency_penalty", slot->sparams.frequency_penalty}, - {"mirostat", slot->sparams.mirostat}, - {"mirostat_tau", slot->sparams.mirostat_tau}, - {"mirostat_eta", slot->sparams.mirostat_eta}, - {"penalize_nl", slot->sparams.penalize_nl}, - {"stop", slot->params.antiprompt}, - {"n_predict", slot->params.n_predict}, - {"n_keep", llama.params.n_keep}, - {"ignore_eos", ignore_eos}, - {"stream", slot->params.stream}, - {"logit_bias", slot->sparams.logit_bias}, - {"n_probs", slot->sparams.n_probs}, - {"grammar", slot->params.grammar}, - }; -} - -static json format_embedding_response(llama_server_context &llama) -{ - return json - { - {"embedding", llama.get_embedding()}, - }; -} - -static json format_timings(llama_client_slot* slot) -{ - return json - { - {"prompt_n", slot->num_prompt_tokens_processed}, - {"prompt_ms", slot->t_prompt_processing}, - {"prompt_per_token_ms", slot->t_prompt_processing / slot->num_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed}, - - {"predicted_n", slot->n_decoded}, - {"predicted_ms", slot->t_token_generation}, - {"predicted_per_token_ms", slot->t_token_generation / slot->n_decoded}, - {"predicted_per_second", 1e3 / slot->t_token_generation * slot->n_decoded}, - }; -} - -static json format_final_response(llama_server_context &llama, llama_client_slot* slot, const std::string &content, const std::vector &probs) -{ - - json res = json - { - {"content", content}, - {"slot_id", slot->id}, - {"stop", true}, - {"model", llama.params.model_alias}, - {"tokens_predicted", slot->n_decoded}, - {"tokens_evaluated", slot->num_prompt_tokens}, - {"generation_settings", format_generation_settings(llama, slot)}, - {"prompt", slot->prompt}, - {"truncated", slot->truncated}, - {"stopped_eos", slot->stopped_eos}, - {"stopped_word", slot->stopped_word}, - {"stopped_limit", slot->stopped_limit}, - {"stopping_word", slot->stopping_word}, - {"tokens_cached", slot->n_past}, - {"timings", format_timings(slot)} - }; - - if (slot->sparams.n_probs > 0) - { - res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); - } - - return res; -} - static json format_partial_response( llama_server_context &llama, llama_client_slot* slot, const std::string &content, const std::vector &probs ) { @@ -1803,191 +2132,6 @@ static json format_detokenized_response(std::string content) {"content", content}}; } -template -static T json_value(const json &body, const std::string &key, const T &default_value) -{ - // Fallback null to default value - return body.contains(key) && !body.at(key).is_null() - ? 
body.value(key, default_value) - : default_value; -} - -static void parse_options_completion(const json &body, llama_client_slot* slot, llama_server_context &llama) -{ - slot_params default_params; - llama_sampling_params default_sparams; - - slot->params.stream = json_value(body, "stream", false); - slot->params.cache_prompt = json_value(body, "cache_prompt", false); - slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(body, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(body, "top_p", default_sparams.top_p); - slot->sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p); - slot->sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n); - slot->sparams.temp = json_value(body, "temperature", default_sparams.temp); - slot->sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty); - slot->sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty); - slot->sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty); - slot->sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(body, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(body, "seed", default_params.seed); - slot->params.grammar = json_value(body, "grammar", default_params.grammar); - slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs); - - if (body.count("prompt") != 0) - { - slot->prompt = body["prompt"]; - } - else - { - slot->prompt = ""; - } - - slot->sparams.logit_bias.clear(); - if (json_value(body, "ignore_eos", false)) - { - slot->sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; - } - - const auto &logit_bias = body.find("logit_bias"); - if (logit_bias != body.end() && logit_bias->is_array()) - { - const int n_vocab = llama_n_vocab(llama.model); - for (const auto &el : *logit_bias) - { - if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) - { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) - { - if (el[1].is_number()) - { - slot->sparams.logit_bias[tok] = el[1].get(); - } - else if (el[1].is_boolean() && !el[1].get()) - { - slot->sparams.logit_bias[tok] = -INFINITY; - } - } - } - } - } - - slot->params.antiprompt.clear(); - const auto &stop = body.find("stop"); - if (stop != body.end() && stop->is_array()) - { - for (const auto &word : *stop) - { - if (!word.empty()) - { - slot->params.antiprompt.push_back(word); - } - } - } - - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); - if(!llama.multimodal) - { - return; - } - - const auto &images_data = body.find("image_data"); - if (images_data != body.end() && images_data->is_array()) - { - for (const auto &img : *images_data) - { - slot_image img_sl; - std::string data_b64 = img["data"].get(); - img_sl.id = img.count("id") != 0 ? 
img["id"].get() : slot->images.size(); - int width, height, channels; - std::vector image_buffer = base64_decode(data_b64); - data_b64.clear(); - auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3); - if (!data) { - LOG_TEE("slot %i - failed to load image id= %i\n", slot->id, img_sl.id); - return; - } - LOG_TEE("slot %i - image id= %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height); - img_sl.img_data.nx = width; - img_sl.img_data.ny = height; - img_sl.img_data.size = width * height * 3; - img_sl.img_data.data = new uint8_t[width * height * 3](); - memcpy(img_sl.img_data.data, data, width * height * 3); - stbi_image_free(data); - img_sl.request_encode_image = true; - slot->images.push_back(img_sl); - } - // process prompt - // example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]} - if (slot->images.size() > 0 && !slot->prompt.is_array()) - { - std::string prompt = slot->prompt.get(); - size_t pos = 0, begin_prefix = 0; - std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) { - size_t end_prefix = pos; - pos += pattern.length(); - size_t end_pos = prompt.find("]", pos); - if (end_pos != std::string::npos) - { - std::string image_id = prompt.substr(pos, end_pos - pos); - try - { - int img_id = std::stoi(image_id); - bool found = false; - for (slot_image &img : slot->images) - { - if (img.id == img_id) { - found = true; - img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); - begin_prefix = end_pos + 1; - break; - } - } - if (!found) { - LOG_TEE("ERROR: Image with id %i not found.\n", img_id); - slot->images.clear(); - return; - } - } catch (const std::invalid_argument& e) { - LOG_TEE("Invalid image number id in prompt\n"); - slot->images.clear(); - return; - } - } - } - slot->prompt = ""; - slot->params.input_suffix = prompt.substr(begin_prefix); - slot->params.cache_prompt = false; // multimodal doesn't support cache prompt - } - } -} - -static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot) -{ - if (body.count("input_prefix") != 0) - { - slot->params.input_prefix = body["input_prefix"]; - } - else - { - slot->params.input_prefix = ""; - } - if (body.count("input_suffix") != 0) - { - slot->params.input_suffix = body["input_suffix"]; - } - else - { - slot->params.input_suffix = ""; - } - parse_options_completion(body, slot, llama); -} static void log_server_request(const httplib::Request &req, const httplib::Response &res) { @@ -2111,123 +2255,48 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - - llama_client_slot* slot = llama.get_slot(json_value(data, "slot_id", -1)); - - if (slot == nullptr) { - LOG_TEE("slot unavailable\n"); - res.status = 404; - res.set_content("slot_error", "text/plain"); - return; - } - - if (data.contains("system_prompt")) { - llama.process_system_prompt_data(data["system_prompt"]); - } - - slot->reset(); - - parse_options_completion(data, slot, llama); - - if (!llama.launch_slot(slot)) - { - res.status = 400; - return; - } - - if (!slot->params.stream) { + const int task_id = llama.request_completion(data, false); + if (!json_value(data, "stream", false)) { std::string completion_text; - - while (slot->is_processing()) - { - if (slot->has_new_token()) - { - 
completion_text += slot->next().text_to_send; - } - else - { - std::this_thread::sleep_for(std::chrono::microseconds(5)); - } + task_result result = llama.next_result(task_id); + if(!result.error && result.stop) { + res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } - - auto probs = slot->generated_token_probs; - if (slot->sparams.n_probs > 0 && slot->stopped_word) + else { - const std::vector stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false); - probs = std::vector(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size()); + res.status = 404; + res.set_content(result.result_json["content"], "text/plain"); + return; } - - const json data = format_final_response(llama, slot, completion_text, probs); - slot_print_timings(slot); - slot->release(); - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } else { - const auto chunked_content_provider = [slot, &llama](size_t, httplib::DataSink & sink) { - size_t sent_token_probs_index = 0; - while (slot->is_processing()) - { - if (slot->has_new_token()) - { // new token notification - const completion_token_output token = slot->next(); - std::vector probs_output = {}; - if (slot->sparams.n_probs > 0) - { - const std::vector to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size()); - if (probs_pos < probs_stop_pos) - { - probs_output = std::vector(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos); - } - sent_token_probs_index = probs_stop_pos; - } - const json data = format_partial_response(llama, slot, token.text_to_send, probs_output); + const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { + while(true) { + task_result result = llama.next_result(task_id); + if(!result.error) { const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; + "data: " + + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; LOG_VERBOSE("data stream", { { "to_send", str } }); if (!sink.write(str.c_str(), str.size())) { - slot->release(); return false; } + if(result.stop) { + break; + } + } else { + break; } - else - { - std::this_thread::sleep_for(std::chrono::microseconds(5)); - } - } - const json data = format_final_response( - llama, slot, - "", - std::vector( - slot->generated_token_probs.begin(), - slot->generated_token_probs.begin() + sent_token_probs_index) - ); - slot_print_timings(slot); - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.data(), str.size())) - { - slot->release(); - return false; } sink.done(); return true; }; - auto on_complete = [slot] (bool) { - slot->release(); - slot->clean_tokens(); + auto on_complete = [task_id, &llama] (bool) { + // cancel + llama.request_cancel(task_id); }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } @@ -2236,129 +2305,48 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - - 
llama_client_slot* slot = llama.get_slot(json_value(data, "slot_id", -1)); - - if (slot == nullptr) - { - LOG_TEE("slot unavailable\n"); - res.status = 404; - res.set_content("slot_error", "text/plain"); - return; - } - - if (data.contains("system_prompt")) - { - llama.process_system_prompt_data(data["system_prompt"]); - } - - slot->reset(); - slot->infill = true; - - parse_options_infill(data, llama, slot); - - if (!llama.launch_slot(slot)) - { - res.status = 400; - return; - } - - if (!slot->params.stream) - { - std::string completion_text = ""; - while (slot->is_processing()) - { - if(slot->has_new_token()) - { - completion_text += slot->next().text_to_send; - } - else - { - std::this_thread::sleep_for(std::chrono::microseconds(5)); - } + const int task_id = llama.request_completion(data, true); + if (!json_value(data, "stream", false)) { + std::string completion_text; + task_result result = llama.next_result(task_id); + if(!result.error && result.stop) { + res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } - - auto probs = slot->generated_token_probs; - if (slot->sparams.n_probs > 0 && slot->stopped_word) + else { - const std::vector stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false); - probs = std::vector(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size()); + res.status = 404; + res.set_content(result.result_json["content"], "text/plain"); + return; } - - const json data = format_final_response(llama, slot, completion_text, probs); - slot_print_timings(slot); - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); - } - else - { - const auto chunked_content_provider = [slot, &llama](size_t, httplib::DataSink & sink) { - size_t sent_token_probs_index = 0; - while (slot->is_processing()) - { - if (slot->has_new_token()) - { - // new token notification - const completion_token_output token = slot->next(); - std::vector probs_output = {}; - if (slot->sparams.n_probs > 0) - { - const std::vector to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size()); - if (probs_pos < probs_stop_pos) - { - probs_output = std::vector(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos); - } - sent_token_probs_index = probs_stop_pos; - } - const json data = format_partial_response(llama, slot, token.text_to_send, probs_output); + } else { + const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { + while(true) { + task_result result = llama.next_result(task_id); + if(!result.error) { const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; + "data: " + + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; LOG_VERBOSE("data stream", { { "to_send", str } }); if (!sink.write(str.c_str(), str.size())) { - slot->release(); return false; } + if(result.stop) { + break; + } + } else { + break; } - else - { - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } - } - const json data = format_final_response( - llama, slot, - "", - std::vector( - slot->generated_token_probs.begin(), - slot->generated_token_probs.begin() + sent_token_probs_index) - ); - 
-                    slot_print_timings(slot);
-                    const std::string str =
-                        "data: " +
-                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                        "\n\n";
-
-                    LOG_VERBOSE("data stream", {
-                        { "to_send", str }
-                    });
-
-                    if (!sink.write(str.data(), str.size()))
-                    {
-                        slot->release();
-                        return false;
-                    }
                 sink.done();
                 return true;
             };
-            auto on_complete = [slot] (bool)
-            {
-                slot->clean_tokens();
-                slot->release();
+            auto on_complete = [task_id, &llama] (bool) {
+                // cancel
+                llama.request_cancel(task_id);
             };
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
@@ -2366,7 +2354,7 @@
     svr.Get("/model.json", [&llama](const httplib::Request &, httplib::Response &res)
             {
-                const json data = format_generation_settings(llama, llama.get_slot(0));
+                const json data = llama.get_model_props();
                 return res.set_content(data.dump(), "application/json");
             });
@@ -2402,23 +2390,18 @@
     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 const json body = json::parse(req.body);
-                llama_client_slot* slot = llama.get_slot(-1);
-                slot->reset();
+                json prompt;
                 if (body.count("content") != 0)
                 {
-                    slot->prompt = body["content"];
+                    prompt = body["content"];
                 }
                 else
                 {
-                    slot->prompt = "";
+                    prompt = "";
                 }
-                llama.params.n_predict = 0;
-                llama.launch_slot(slot);
-                while (slot->is_processing()) {
-                    std::this_thread::sleep_for(std::chrono::microseconds(10));
-                }
-                const json data = format_embedding_response(llama);
-                return res.set_content(data.dump(), "application/json");
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+                task_result result = llama.next_result(task_id);
+                return res.set_content(result.result_json.dump(), "application/json");
            });
 
    svr.set_logger(log_server_request);
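Note (not part of the patch): the sketch below illustrates how the new thread-safe pipeline added above is meant to be driven. The HTTP handlers act as producers (request_completion / request_cancel push task_server entries onto queue_tasks under mutex_tasks) and process_tasks, called from the update_slots() loop, is the consumer; the send_*_response helpers push task_result entries onto queue_results, and next_result polls for the entry matching a given task id. The helper name run_completion is hypothetical; everything it calls is defined in the patched server.cpp, and it assumes an initialized llama_server_context whose update_slots() loop runs on a different thread than the HTTP handlers, which is what makes the locked queues necessary.

    // Hypothetical helper (illustration only), mirroring the non-streaming /completion handler.
    static std::string run_completion(llama_server_context & llama, const json & data, bool infill = false)
    {
        const int task_id = llama.request_completion(data, infill);  // producer side: enqueue a task
        while (true)
        {
            task_result result = llama.next_result(task_id);         // blocks, polling queue_results
            if (result.error || result.stop)
            {
                // error payload: {"content": "<message>"}; final payload: content, timings,
                // generation_settings, ... as assembled by send_final_response()
                return result.result_json.dump();
            }
            // result.stop == false: streaming chunk produced by send_partial_response();
            // a streaming handler would forward it to the client and keep polling
        }
    }

    // usage: const std::string body = run_completion(llama, json{{"prompt", "Hello"}, {"n_predict", 16}});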