diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 53a209736..2f3c3fe4f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -32,7 +32,7 @@ struct server_params
 {
     std::string hostname = "127.0.0.1";
     std::string public_path = "examples/server/public";
-    int32_t port = 8040;
+    int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
 };
@@ -78,6 +78,8 @@ struct slot_params
     std::string grammar = "";         // optional BNF-like grammar to constrain sampling
     bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
     std::vector<std::string> antiprompt;
+    json input_prefix;
+    json input_suffix;
 };

 // completion token output with probabilities
@@ -233,6 +235,7 @@ struct llama_client_slot
     std::string stopping_word;
     int32_t multibyte_pending = 0;
     size_t sent_count = 0;
+    bool infill = false;

     struct slot_params params;
     struct llama_sampling_params sparams;
@@ -257,6 +260,7 @@ struct llama_client_slot
         multibyte_pending = 0;
         n_past = 0;
         sent_count = 0;
+        infill = false;

         if (grammar != nullptr) {
             llama_grammar_free(grammar);
@@ -508,82 +512,6 @@ struct llama_server_context
         return true;
     }

-    void loadInfill()
-    {
-        // bool suff_rm_leading_spc = true;
-        // if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
-        //     params.input_suffix.erase(0, 1);
-        //     suff_rm_leading_spc = false;
-        // }
-
-        // auto prefix_tokens = tokenize(params.input_prefix, false);
-        // auto suffix_tokens = tokenize(params.input_suffix, false);
-        // const int space_token = 29871;
-        // if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
-        //     suffix_tokens.erase(suffix_tokens.begin());
-        // }
-        // prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
-        // prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
-        // prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
-        // prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-        // prefix_tokens.push_back(llama_token_middle(ctx));
-        // auto prompt_tokens = prefix_tokens;
-
-        // num_prompt_tokens = prompt_tokens.size();
-
-        // if (params.n_keep < 0)
-        // {
-        //     params.n_keep = (int)num_prompt_tokens;
-        // }
-        // params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
-
-        // // if input prompt is too big, truncate like normal
-        // if (num_prompt_tokens >= (size_t)params.n_ctx)
-        // {
-        //     printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
-        //     // todo we probably want to cut from both sides
-        //     const int n_left = (params.n_ctx - params.n_keep) / 2;
-        //     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-        //     const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-        //     new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-        //     std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
-
-        //     LOG_VERBOSE("input truncated", {
-        //         {"n_ctx", params.n_ctx},
-        //         {"n_keep", params.n_keep},
-        //         {"n_left", n_left},
-        //         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-        //     });
-
-        //     truncated = true;
-        //     prompt_tokens = new_tokens;
-        // }
-        // else
-        // {
-        //     const size_t ps = num_prompt_tokens;
-        //     std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-        //     std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
-        // }
-
-        // // compare the evaluated prompt with the new prompt
-        // n_past = common_part(embd, prompt_tokens);
-        // embd = prompt_tokens;
-        // if (n_past == num_prompt_tokens)
-        // {
-        //     // we have to evaluate at least 1 token to generate logits.
-        //     printf("we have to evaluate at least 1 token to generate logits\n");
-        //     n_past--;
-        // }
-
-        // LOG_VERBOSE("prompt ingested", {
-        //     {"n_past", n_past},
-        //     {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
-        //     {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
-        //     });
-
-        // has_next_token = true;
-    }
-
     void cleanKVCache() {
         // clear the entire KV cache
         for (int i = 0; i < params.n_parallel; ++i)
@@ -839,8 +767,29 @@ struct llama_server_context
             if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
                 slot.state = PROCESSING;
                 slot.command = NONE;
+                std::vector<llama_token> prompt_tokens;
+                if(slot.infill) {
+                    bool suff_rm_leading_spc = true;
+                    if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                        params.input_suffix.erase(0, 1);
+                        suff_rm_leading_spc = false;
+                    }
+                    auto prefix_tokens = tokenize(slot.params.input_prefix, false);
+                    auto suffix_tokens = tokenize(slot.params.input_suffix, false);
+                    const int space_token = 29871;
+                    if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
+                        suffix_tokens.erase(suffix_tokens.begin());
+                    }
+                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
+                    prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+                    prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+                    prefix_tokens.push_back(llama_token_middle(ctx));
+                    prompt_tokens = prefix_tokens;
+                } else {
+                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+                }
-                auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                 slot.num_prompt_tokens = prompt_tokens.size();
                 slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
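Note (not part of the patch): the infill branch above builds the prompt in the usual fill-in-the-middle order. A minimal sketch of the resulting layout, using only the token helpers that already appear in the diff (llama_token_bos/prefix/suffix/middle); the helper name and signature here are illustrative:

// Illustrative sketch only; mirrors the ordering built in the infill branch above.
// Final layout: [BOS][FIM_PREFIX] prefix tokens [FIM_SUFFIX] suffix tokens [FIM_MIDDLE]
// The model then generates the "middle" that joins prefix and suffix.
#include "llama.h"
#include <vector>

static std::vector<llama_token> build_infill_prompt(llama_context * ctx,
                                                    const std::vector<llama_token> & prefix_tokens,
                                                    const std::vector<llama_token> & suffix_tokens) {
    std::vector<llama_token> out;
    out.push_back(llama_token_bos(ctx));      // always add BOS
    out.push_back(llama_token_prefix(ctx));   // marks the code before the cursor
    out.insert(out.end(), prefix_tokens.begin(), prefix_tokens.end());
    out.push_back(llama_token_suffix(ctx));   // marks the code after the cursor
    out.insert(out.end(), suffix_tokens.begin(), suffix_tokens.end());
    out.push_back(llama_token_middle(ctx));   // generation continues from here
    return out;
}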
@@ -1304,7 +1253,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }

-static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot)
+static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
     const bool ignore_eos = eos_bias != slot->sparams.logit_bias.end() &&
@@ -1428,7 +1377,7 @@ static T json_value(const json &body, const std::string &key, const T &default_v
                : default_value;
 }

-static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama)
+static void parse_options_completion(const json &body, llama_client_slot* slot, llama_server_context &llama)
 {
     slot_params default_params;
     llama_sampling_params default_sparams;
@@ -1508,26 +1457,26 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
 }

-// static void parse_options_infill(const json &body, llama_server_context &llama)
-// {
-//     if (body.count("input_prefix") != 0)
-//     {
-//         llama.params.input_prefix = body["input_prefix"];
-//     }
-//     else
-//     {
-//         llama.params.input_prefix = "";
-//     }
-//     if (body.count("input_suffix") != 0)
-//     {
-//         llama.params.input_suffix = body["input_suffix"];
-//     }
-//     else
-//     {
-//         llama.params.input_suffix = "";
-//     }
-//     parse_options_completion(body, slot, llama);
-// }
+static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)
+{
+    if (body.count("input_prefix") != 0)
+    {
+        slot->params.input_prefix = body["input_prefix"];
+    }
+    else
+    {
+        slot->params.input_prefix = "";
+    }
+    if (body.count("input_suffix") != 0)
+    {
+        slot->params.input_suffix = body["input_suffix"];
+    }
+    else
+    {
+        slot->params.input_suffix = "";
+    }
+    parse_options_completion(body, slot, llama);
+}

 static void log_server_request(const Request &req, const Response &res)
 {
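For reference (not part of the patch): parse_options_infill only pulls input_prefix and input_suffix out of the request body and then falls through to parse_options_completion, so regular completion options can ride along in the same JSON object. A hypothetical body using the json type already used in this file; n_predict and temperature are assumed standard completion options, not introduced by this patch:

// Hypothetical /infill request body (illustration only).
json infill_body = {
    {"input_prefix", "def fib(n):\n    "},   // text before the cursor
    {"input_suffix", "\n    return a"},      // text after the cursor
    {"slot_id",      -1},                    // -1 = any free slot (see getSlot in the handlers below)
    {"n_predict",    64},                    // assumed regular completion option
    {"temperature",  0.2}                    // likewise
};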
@@ -1682,7 +1631,6 @@ int main(int argc, char **argv)

     svr.Post("/completion", [&llama](const Request &req, Response &res)
             {
-
                 json data = json::parse(req.body);
                 llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1));
@@ -1702,7 +1650,7 @@ int main(int argc, char **argv)

                 slot->reset();

-                parse_options_completion(json::parse(req.body), slot, llama);
+                parse_options_completion(data, slot, llama);

                 if (!llama.launchSlot(slot))
                 {
@@ -1711,44 +1659,36 @@ int main(int argc, char **argv)
                 }

                 if (!slot->params.stream) {
-                    // if (llama.params.n_beams) {
-                    //     // Fill llama.generated_token_probs vector with final beam.
-                    //     llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                    //                       llama.n_past, llama.n_remain);
-                    //     // Translate llama.generated_token_probs to llama.generated_text.
-                    //     append_to_generated_text_from_generated_token_probs(llama);
-                    // } else {
-                    //     size_t stop_pos = std::string::npos;
+                    std::string completion_text = "";
+                    if (llama.params.n_beams) {
+                        // // Fill llama.generated_token_probs vector with final beam.
+                        // llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                        //                   slot->n_past, llama.n_remain);
+                        // // Translate llama.generated_token_probs to llama.generated_text.
+                        // append_to_generated_text_from_generated_token_probs(llama);
+                    } else {
-                    //     while (llama.has_next_token) {
-                    //         const completion_token_output token_with_probs = llama.doCompletion();
-                    //         const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
+                        while (slot->isProcessing()) {
+                            if(slot->hasNewToken()) {
+                                completion_text += slot->next().text_to_send;
+                            } else {
+                                std::this_thread::sleep_for(std::chrono::milliseconds(5));
+                            }
+                        }
+                    }
-                    //         stop_pos = llama.findStoppingStrings(llama.generated_text,
-                    //                                              token_text.size(), STOP_FULL);
-                    //     }
+                    auto probs = slot->generated_token_probs;
+                    if (slot->sparams.n_probs > 0 && slot->stopped_word) {
+                        const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
+                        probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
+                    }
-                    //     if (stop_pos == std::string::npos) {
-                    //         stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-                    //     }
-                    //     if (stop_pos != std::string::npos) {
-                    //         llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                    //                                    llama.generated_text.end());
-                    //     }
-                    // }
+                    const json data = format_final_response(llama, slot, completion_text, probs);
-                    // auto probs = llama.generated_token_probs;
-                    // if (llama.params.n_probs > 0 && llama.stopped_word) {
-                    //     const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
-                    //     probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
-                    // }
+                    //llama_print_timings(llama.ctx);
-                    // const json data = format_final_response(llama, llama.generated_text, probs);
-
-                    // llama_print_timings(llama.ctx);
-
-                    // res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
-                    //                 "application/json");
+                    res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                                    "application/json");
                 } else {
                     const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
                         size_t sent_token_probs_index = 0;
@@ -1810,131 +1750,101 @@ int main(int argc, char **argv)
                     res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
                 }
             });
-    // svr.Post("/infill", [&llama](const Request &req, Response &res)
-    //         {
-    //             auto lock = llama.lock();
+    svr.Post("/infill", [&llama](const Request &req, Response &res)
+            {
-    //             llama.rewind();
+                json data = json::parse(req.body);
-    //             llama_reset_timings(llama.ctx);
+                llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1));
-    //             parse_options_infill(json::parse(req.body), llama);
+                if(slot == nullptr) {
+                    LOG_TEE("slot unavailable\n");
+                    res.status = 404;
+                    res.set_content("slot_error", "text/plain");
+                    return;
+                }
-    //             if (!llama.loadGrammar())
-    //             {
-    //                 res.status = 400;
-    //                 return;
-    //             }
-    //             llama.loadInfill();
-    //             llama.beginCompletion();
-    //             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
-    //                 size_t sent_count = 0;
-    //                 size_t sent_token_probs_index = 0;
+                if(data.contains("system_prompt")) {
+                    llama.processSystemPromptData(data["system_prompt"]);
+                }
-    //                 while (llama.has_next_token) {
-    //                     const completion_token_output token_with_probs = llama.doCompletion();
-    //                     if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
-    //                         continue;
-    //                     }
-    //                     const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+                // llama_reset_timings(llama.ctx);
-    //                     size_t pos = std::min(sent_count, llama.generated_text.size());
+                slot->reset();
+                slot->infill = true;
-    //                     const std::string str_test = llama.generated_text.substr(pos);
-    //                     bool is_stop_full = false;
-    //                     size_t stop_pos =
-    //                         llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
-    //                     if (stop_pos != std::string::npos) {
-    //                         is_stop_full = true;
-    //                         llama.generated_text.erase(
-    //                             llama.generated_text.begin() + pos + stop_pos,
-    //                             llama.generated_text.end());
-    //                         pos = std::min(sent_count, llama.generated_text.size());
-    //                     } else {
-    //                         is_stop_full = false;
-    //                         stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
-    //                                                              STOP_PARTIAL);
-    //                     }
+                parse_options_infill(data, llama, slot);
-    //                     if (
-    //                         stop_pos == std::string::npos ||
-    //                         // Send rest of the text if we are at the end of the generation
-    //                         (!llama.has_next_token && !is_stop_full && stop_pos > 0)
-    //                     ) {
-    //                         const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+                if (!llama.launchSlot(slot))
+                {
+                    res.status = 400;
+                    return;
+                }
-    //                         sent_count += to_send.size();
+                const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
+                    size_t sent_token_probs_index = 0;
+                    while(slot->isProcessing()) {
+                        if(slot->hasNewToken()) { // new token notification
+                            const completion_token_output token = slot->next();
+                            std::vector<completion_token_output> probs_output = {};
+                            if (slot->sparams.n_probs > 0) {
+                                const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
+                                size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
+                                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
+                                if (probs_pos < probs_stop_pos) {
+                                    probs_output = std::vector<completion_token_output>(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos);
+                                }
+                                sent_token_probs_index = probs_stop_pos;
+                            }
+                            const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
+                            const std::string str =
+                                "data: " +
+                                data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                                "\n\n";
+                            LOG_VERBOSE("data stream", {
+                                { "to_send", str }
+                            });
+                            if(!sink.write(str.c_str(), str.size())) {
+                                slot->release();
+                                return false;
+                            }
+                        } else {
+                            std::this_thread::sleep_for(std::chrono::milliseconds(5));
+                        }
+                    }
+                    const json data = format_final_response(
+                        llama, slot,
+                        "",
+                        std::vector<completion_token_output>(
+                            slot->generated_token_probs.begin(),
+                            slot->generated_token_probs.begin() + sent_token_probs_index)
+                    );
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });
+                    if (!sink.write(str.data(), str.size())) {
+                        slot->release();
+                        return false;
+                    }
+                    sink.done();
+                    return true;
+                };
+                auto on_complete = [slot, &llama] (bool) {
+                    slot->sent_tokens = 0;
+                    slot->generated_token_probs.clear();
+                    slot->release();
+                };
+                res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            });
-    //                         std::vector<completion_token_output> probs_output = {};
-
-    //                         if (llama.params.n_probs > 0) {
-    //                             const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
-    //                             size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
-    //                             size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
-    //                             if (probs_pos < probs_stop_pos) {
-    //                                 probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
-    //                             }
-    //                             sent_token_probs_index = probs_stop_pos;
-    //                         }
-
-    //                         const json data = format_partial_response(llama, to_send, probs_output);
-
-    //                         const std::string str =
-    //                             "data: " +
-    //                             data.dump(-1, ' ', false, json::error_handler_t::replace) +
-    //                             "\n\n";
-
-    //                         LOG_VERBOSE("data stream", {
-    //                             { "to_send", str }
-    //                         });
-
-    //                         if (!sink.write(str.data(), str.size())) {
-    //                             LOG_VERBOSE("stream closed", {});
-    //                             llama_print_timings(llama.ctx);
-    //                             return false;
-    //                         }
-    //                     }
-
-    //                     if (!llama.has_next_token) {
-    //                         // Generation is done, send extra information.
-    //                         const json data = format_final_response(
-    //                             llama,
-    //                             "",
-    //                             std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
-    //                         );
-
-    //                         const std::string str =
-    //                             "data: " +
-    //                             data.dump(-1, ' ', false, json::error_handler_t::replace) +
-    //                             "\n\n";
-
-    //                         LOG_VERBOSE("data stream", {
-    //                             { "to_send", str }
-    //                         });
-
-    //                         if (!sink.write(str.data(), str.size())) {
-    //                             LOG_VERBOSE("stream closed", {});
-    //                             llama_print_timings(llama.ctx);
-    //                             return false;
-    //                         }
-    //                     }
-    //                 }
-
-    //                 llama_print_timings(llama.ctx);
-    //                 sink.done();
-    //                 return true;
-    //             };
-    //             const auto on_complete = [&](bool) {
-    //                 llama.mutex.unlock();
-    //             };
-    //             lock.release();
-    //             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-    //         });
-
-    // svr.Get("/model.json", [&llama](const Request &, Response &res)
-    //         {
-    //             const json data = format_generation_settings(llama);
-    //             return res.set_content(data.dump(), "application/json"); });
+    svr.Get("/model.json", [&llama](const Request &, Response &res)
+            {
+                const json data = format_generation_settings(llama, llama.getSlot(0));
+                return res.set_content(data.dump(), "application/json"); });

     svr.Options(R"(/.*)", [](const Request &, Response &res)
                 { return res.set_content("", "application/json"); });
@@ -1965,29 +1875,29 @@ int main(int argc, char **argv)
                 const json data = format_detokenized_response(content);
                 return res.set_content(data.dump(), "application/json"); });

-    // svr.Post("/embedding", [&llama](const Request &req, Response &res)
-    //         {
-    //             auto lock = llama.lock();
+    svr.Post("/embedding", [&llama](const Request &req, Response &res)
+            {
+                const json body = json::parse(req.body);
-    //             const json body = json::parse(req.body);
+                llama_client_slot* slot = llama.getSlot(-1);
-    //             llama.rewind();
-    //             llama_reset_timings(llama.ctx);
-    //             if (body.count("content") != 0)
-    //             {
-    //                 llama.prompt = body["content"];
-    //             }
-    //             else
-    //             {
-    //                 llama.prompt = "";
-    //             }
-    //             llama.params.n_predict = 0;
-    //             llama.loadPrompt();
-    //             llama.beginCompletion();
-    //             llama.doCompletion();
-
-    //             const json data = format_embedding_response(llama);
-    //             return res.set_content(data.dump(), "application/json"); });
+                slot->reset();
+                //llama_reset_timings(llama.ctx);
+                if (body.count("content") != 0)
+                {
+                    slot->prompt = body["content"];
+                }
+                else
+                {
+                    slot->prompt = "";
+                }
+                llama.params.n_predict = 0;
+                llama.launchSlot(slot);
+                while(slot->isProcessing()) {
+                    std::this_thread::sleep_for(std::chrono::microseconds(10));
+                }
+                const json data = format_embedding_response(llama);
+                return res.set_content(data.dump(), "application/json"); });

     svr.set_logger(log_server_request);
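Example usage (not part of the patch): a minimal client sketch, assuming the same cpp-httplib and nlohmann/json headers the server example already uses, and the default 127.0.0.1:8080 address from server_params above. The /infill route always streams, so the response body arrives as "data: {...}" SSE chunks:

// Illustrative client for the new /infill route; field names come from the patch,
// everything else (file layout, includes) is assumed.
#include <iostream>
#include "httplib.h"
#include "json.hpp"

using json = nlohmann::json;

int main() {
    json body = {
        {"input_prefix", "int add(int a, int b) {\n    "},  // code before the cursor
        {"input_suffix", "\n}"},                             // code after the cursor
        {"slot_id", -1}                                      // -1 lets the server pick an idle slot
    };

    httplib::Client cli("127.0.0.1", 8080);
    auto res = cli.Post("/infill", body.dump(), "application/json");

    if (res && res->status == 200) {
        std::cout << res->body << std::endl;  // concatenated "data: {...}\n\n" SSE chunks
    } else {
        std::cerr << "request failed" << std::endl;
    }
    return 0;
}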