From 63f99b1ea622cfe4f9ddeda2ac7a8231416e0dbb Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 11 Oct 2023 18:14:11 -0400 Subject: [PATCH] implementing parallel decoding in server example --- examples/server/server.cpp | 1730 +++++++++++++++++++++--------------- 1 file changed, 1017 insertions(+), 713 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8c5318c65..00b46bc3d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -18,6 +18,8 @@ #include "json-schema-to-grammar.mjs.hpp" #include <cstddef> +#include <thread> +#include <chrono> #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 @@ -35,6 +37,68 @@ struct server_params int32_t write_timeout = 600; }; +static bool server_verbose = false; + +#if SERVER_VERBOSE != 1 +#define LOG_VERBOSE(MSG, ...) +#else +#define LOG_VERBOSE(MSG, ...) \ + do \ + { \ + if (server_verbose) \ + { \ + server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ + } \ + } while (0) +#endif + +#define LOG_ERROR(MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) + + +// parallel +enum slot_state +{ + IDLE, + SLEEPING, + PROCESSING +}; + +enum slot_command { + NONE, + LOAD_PROMPT, + RELEASE +}; + +struct slot_params { + bool stream = true; + uint32_t seed = -1; // RNG seed + int32_t n_predict = 128; // new tokens to predict + + // sampler params + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled + int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.00f; // 0.0 = disabled + float presence_penalty = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + int n_probs = 0; + bool penalize_nl = false; + + std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens + + std::string grammar = ""; // optional BNF-like grammar to constrain sampling + bool remember_generation = false; // remember a part of the prompt to avoid reprocessing the entire prompt + std::vector<std::string> antiprompt; +}; + // completion token output with probabilities struct completion_token_output { @@ -69,6 +133,23 @@ static bool ends_with(const std::string &str, const std::string &suffix) 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } +static void slot_params_to_gpt_params(const slot_params &src, gpt_params & dst) +{ + dst.frequency_penalty = src.frequency_penalty; + dst.temp = src.temp; + dst.top_k = src.top_k; + dst.top_p = src.top_p; + dst.grammar = src.grammar; + dst.logit_bias = src.logit_bias; + dst.mirostat = src.mirostat; + dst.mirostat_eta = src.mirostat_eta; + dst.mirostat_tau = src.mirostat_tau; + dst.typical_p = src.typical_p; + dst.repeat_penalty = src.repeat_penalty; + dst.repeat_last_n = src.repeat_last_n; + dst.presence_penalty = src.presence_penalty; +} + static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { @@ -162,49 +243,24 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector - std::vector<completion_token_output> generated_token_probs; - - size_t num_prompt_tokens = 0; - size_t num_tokens_predicted
= 0; +struct llama_client_slot +{ + int id; + json prompt; + // generation props + int32_t num_prompt_tokens = 0; + int32_t n_decoded = 0; + int32_t i_batch = -1; size_t n_past = 0; - size_t n_remain = 0; - json prompt; - std::vector<llama_token> embd; + std::string generated_text = ""; + int num_tokens_predicted = 0; + llama_token sampled; + std::vector<llama_token> context_tokens; std::vector<llama_token> last_n_tokens; - - llama_model *model = nullptr; - llama_context *ctx = nullptr; - gpt_params params; - int n_ctx; - - grammar_parser::parse_state parsed_grammar; - llama_grammar *grammar = nullptr; - + std::vector<completion_token_output> generated_token_probs; + int sent_tokens = 0; + slot_state state = IDLE; + slot_command command = NONE; bool truncated = false; bool stopped_eos = false; bool stopped_word = false; @@ -212,6 +268,122 @@ struct llama_server_context std::string stopping_word; int32_t multibyte_pending = 0; + slot_params params; + + // grammar props + grammar_parser::parse_state parsed_grammar; + llama_grammar *grammar = nullptr; + + void reset() { + state = IDLE; + command = NONE; + num_prompt_tokens = 0; + num_tokens_predicted = 0; + generated_text = ""; + generated_token_probs.clear(); + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + multibyte_pending = 0; + n_past = 0; + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + } + + // llama_set_rng_seed(ctx, params.seed); TODO: does the RNG seed matter when decoding in batches? + } + + bool loadGrammar() + { + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); + return false; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + + // TODO: fix this comment + // { + // auto it = params.logit_bias.find(llama_token_eos(ctx)); + // if (it != params.logit_bias.end() && it->second == -INFINITY) { + // LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); + // } + // } + + std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + return true; + } + + bool hasNewToken() { + return generated_token_probs.size() > (size_t) sent_tokens; + } + + bool available() { + return state == IDLE && + command == NONE && + !params.remember_generation; + } + + bool isProcessing() { + return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; + } + + completion_token_output next() { + completion_token_output tkn = generated_token_probs.at(sent_tokens); + sent_tokens++; + return tkn; + } + + void addTokenString(completion_token_output token) { + if(command == RELEASE) { + generated_token_probs.clear(); + sent_tokens = 0; + return; + } + context_tokens.push_back(token.tok); + generated_token_probs.push_back(token); + num_tokens_predicted++; + } + + void release() { + if(state == PROCESSING) { + command = RELEASE; + } + } +}; + +struct llama_server_context +{ + std::vector<llama_client_slot> slots; + + // system prompt + std::string system_prompt = ""; + bool update_system_prompt = false; + std::vector<llama_token> tokens_system; + int32_t n_tokens_system = 0; + + // broadcast to all clients to keep the same prompt format + std::string user_name = ""; // this should be the anti-prompt + std::string assistant_name = ""; // this is used to generate the prompt + + llama_model *model = nullptr; + llama_context *ctx = nullptr; + llama_batch batch; + std::vector<llama_token_data>
candidates; + bool all_slots_are_idle = false; + gpt_params params; + int n_ctx; + int n_vocab; + std::mutex mutex; std::unique_lock<std::mutex> lock() @@ -231,29 +403,11 @@ struct llama_server_context llama_free_model(model); model = nullptr; } - } - void rewind() - { - params.antiprompt.clear(); - params.grammar.clear(); - num_prompt_tokens = 0; - num_tokens_predicted = 0; - generated_text = ""; - generated_text.reserve(n_ctx); - generated_token_probs.clear(); - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_remain = 0; - n_past = 0; - - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; + for(auto &slot : slots) { + if(slot.grammar) { + llama_grammar_free(slot.grammar); + } } } @@ -267,11 +421,31 @@ struct llama_server_context return false; } n_ctx = llama_n_ctx(ctx); - last_n_tokens.resize(n_ctx); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + n_vocab = llama_n_vocab(model); + candidates.reserve(n_vocab); return true; } + void initialize() { + // create slots + LOG_TEE("Available slots:\n"); + for (int i = 0; i < params.n_parallel; i++) + { + llama_client_slot slot; + slot.id = i; + slot.last_n_tokens.resize(params.n_predict); // max prediction per slot + slot.reset(); + std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0); + LOG_TEE(" - slot %i\n", slot.id); + slots.push_back(slot); + } + batch = llama_batch_init(params.n_ctx, 0); + + // empty system prompt + system_prompt = ""; + all_slots_are_idle = true; + } + std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const { // If `add_bos` is true, we only add BOS, when json_prompt is a string, @@ -317,305 +491,236 @@ struct llama_server_context return prompt_tokens; } - bool loadGrammar() - { - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); + void processPrompt() { + + //params.n_keep = std::min(n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + // if (num_prompt_tokens >= (size_t)n_ctx) + // { + // const int n_left = (n_ctx - params.n_keep) / 2; + // std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + // const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + // new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + // std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + // LOG_VERBOSE("input truncated", { + // {"n_ctx", n_ctx}, + // {"n_keep", params.n_keep}, + // {"n_left", n_left}, + // {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + // }); + + // truncated = true; + // prompt_tokens = new_tokens; + // } + // else + // { + // const size_t ps = num_prompt_tokens; + // std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + // std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + // } + + // compare the evaluated prompt with the new prompt + } + + + llama_client_slot* getSlot(int id) { + for (llama_client_slot & slot : slots) + { + if ((id == -1 && slot.available()) || slot.id == id) { - auto it = 
params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) { - LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); - } + return &slot; } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + return nullptr; + } + + bool launchSlot(llama_client_slot* &slot) { + if(!slot->loadGrammar()) { + return false; + } + all_slots_are_idle = false; + slot->command = LOAD_PROMPT; + LOG_TEE("slot %i is processing\n", slot->id); return true; } void loadInfill() { - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } + // bool suff_rm_leading_spc = true; + // if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + // params.input_suffix.erase(0, 1); + // suff_rm_leading_spc = false; + // } - auto prefix_tokens = tokenize(params.input_prefix, false); - auto suffix_tokens = tokenize(params.input_suffix, false); - const int space_token = 29871; - if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); - prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); - prefix_tokens.push_back(llama_token_middle(ctx)); - auto prompt_tokens = prefix_tokens; + // auto prefix_tokens = tokenize(params.input_prefix, false); + // auto suffix_tokens = tokenize(params.input_suffix, false); + // const int space_token = 29871; + // if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + // suffix_tokens.erase(suffix_tokens.begin()); + // } + // prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + // prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS + // prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); + // prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + // prefix_tokens.push_back(llama_token_middle(ctx)); + // auto prompt_tokens = prefix_tokens; - num_prompt_tokens = prompt_tokens.size(); + // num_prompt_tokens = prompt_tokens.size(); - if (params.n_keep < 0) - { - params.n_keep = (int)num_prompt_tokens; - } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + // if (params.n_keep < 0) + // { + // params.n_keep = (int)num_prompt_tokens; + // } + // params.n_keep = std::min(params.n_ctx - 4, params.n_keep); - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) - { - printf("Input prompt is too big, truncating. 
Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); - // todo we probably want to cut from both sides - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + // // if input prompt is too big, truncate like normal + // if (num_prompt_tokens >= (size_t)params.n_ctx) + // { + // printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); + // // todo we probably want to cut from both sides + // const int n_left = (params.n_ctx - params.n_keep) / 2; + // std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + // const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + // new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + // std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - LOG_VERBOSE("input truncated", { - {"n_ctx", params.n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); + // LOG_VERBOSE("input truncated", { + // {"n_ctx", params.n_ctx}, + // {"n_keep", params.n_keep}, + // {"n_left", n_left}, + // {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + // }); - truncated = true; - prompt_tokens = new_tokens; - } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } + // truncated = true; + // prompt_tokens = new_tokens; + // } + // else + // { + // const size_t ps = num_prompt_tokens; + // std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + // std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + // } - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - printf("we have to evaluate at least 1 token to generate logits\n"); - n_past--; - } + // // compare the evaluated prompt with the new prompt + // n_past = common_part(embd, prompt_tokens); + // embd = prompt_tokens; + // if (n_past == num_prompt_tokens) + // { + // // we have to evaluate at least 1 token to generate logits. 
+ // printf("we have to evaluate at least 1 token to generate logits\n"); + // n_past--; + // } - LOG_VERBOSE("prompt ingested", { - {"n_past", n_past}, - {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); + // LOG_VERBOSE("prompt ingested", { + // {"n_past", n_past}, + // {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, + // {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, + // }); - has_next_token = true; - } - void loadPrompt() - { - auto prompt_tokens = tokenize(prompt, true); // always add BOS - - num_prompt_tokens = prompt_tokens.size(); - - if (params.n_keep < 0) - { - params.n_keep = (int)num_prompt_tokens; - } - params.n_keep = std::min(n_ctx - 4, params.n_keep); - - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)n_ctx) - { - const int n_left = (n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; - } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } - - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. 
- n_past--; - } - - LOG_VERBOSE("prompt ingested", { - {"n_past", n_past}, - {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); - - has_next_token = true; + // has_next_token = true; } - void beginCompletion() - { - // number of tokens to keep when resetting context - n_remain = params.n_predict; - llama_set_rng_seed(ctx, params.seed); + void updateSystemPrompt() { + tokens_system = ::llama_tokenize(ctx, system_prompt, true); + n_tokens_system = tokens_system.size(); + + batch.n_tokens = n_tokens_system; + + // clear the entire KV cache + for (int i = 0; i < params.n_parallel; ++i) + { + llama_kv_cache_seq_rm(ctx, i, 0, -1); + } + + for (int32_t i = 0; i < batch.n_tokens; ++i) + { + batch.token[i] = tokens_system[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; + } + + if (llama_decode(ctx, batch) != 0) + { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return; + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < params.n_parallel; ++i) + { + llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); + } + + LOG_TEE("system prompt updated\n"); + update_system_prompt = false; } - completion_token_output nextToken() - { - completion_token_output result; - result.tok = -1; - - if (embd.size() >= (size_t)n_ctx) + void notifySystemPromptChanged() { + // release all slots + for (llama_client_slot &slot : slots) { - // Shift context - - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; - - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - - for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) - { - embd[i - n_discard] = embd[i]; - } - embd.resize(embd.size() - n_discard); - - n_past -= n_discard; - - truncated = true; - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); + slot.release(); } - - bool tg = true; - while (n_past < embd.size()) - { - int n_eval = (int)embd.size() - n_past; - tg = n_eval == 1; - if (n_eval > params.n_batch) - { - n_eval = params.n_batch; - } - - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) - { - LOG_ERROR("failed to eval", { - {"n_eval", n_eval}, - {"n_past", n_past}, - {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); - has_next_token = false; - return result; - } - n_past += n_eval; + waitAllAreIdle(); + all_slots_are_idle = true; + // wait until system prompt load + update_system_prompt = true; + while(update_system_prompt) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); } - - if (params.n_predict == 0) - { - has_next_token = false; - result.tok = llama_token_eos(ctx); - return result; - } - - { - // out of user input, sample next token - std::vector candidates; - candidates.reserve(llama_n_vocab(model)); - - result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates); - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - const int32_t n_probs = params.n_probs; - if (params.temp <= 0 && n_probs > 0) - { - // For llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &candidates_p); - } - - for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) - { - result.probs.push_back({candidates_p.data[i].id, 
candidates_p.data[i].p}); - } - - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(result.tok); - if (tg) { - num_tokens_predicted++; - } - } - - // add it to the context - embd.push_back(result.tok); - // decrement remaining sampling budget - --n_remain; - - if (!embd.empty() && embd.back() == llama_token_eos(ctx)) - { - // stopping_word = llama_token_to_piece(ctx, embd.back()); - has_next_token = false; - stopped_eos = true; - LOG_VERBOSE("eos token found", {}); - return result; - } - - has_next_token = params.n_predict == -1 || n_remain != 0; - return result; + // system prompt loaded, continue } - size_t findStoppingStrings(const std::string &text, const size_t last_token_size, - const stop_type type) + void processSystemPromptData(json sys_props) { + system_prompt = sys_props.value("system_prompt", ""); + user_name = sys_props.value("anti_prompt", ""); + assistant_name = sys_props.value("assistant_name", ""); + notifySystemPromptChanged(); + } + + void waitAllAreIdle() { + bool wait = true; + while(wait) { + wait = false; + for (auto &slot : slots) + { + if (!slot.available()) + { + wait = true; + break; + } + } + } + } + + size_t findStoppingStrings(const size_t last_token_size, + const stop_type type, llama_client_slot & slot) { size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) + for (const std::string &word : slot.params.antiprompt) { size_t pos; if (type == STOP_FULL) { const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); + const size_t from_pos = slot.generated_text.size() > tmp ? slot.generated_text.size() - tmp : 0; + pos = slot.generated_text.find(word, from_pos); } else { - pos = find_partial_stop_string(word, text); + pos = find_partial_stop_string(word, slot.generated_text); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { if (type == STOP_FULL) { - stopping_word = word; - stopped_word = true; - has_next_token = false; + slot.stopping_word = word; + slot.stopped_word = true; } stop_pos = pos; } @@ -623,70 +728,255 @@ struct llama_server_context return stop_pos; } - completion_token_output doCompletion() - { - auto token_with_probs = nextToken(); + bool processToken(completion_token_output & result, llama_client_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + slot.last_n_tokens.erase(slot.last_n_tokens.begin()); + slot.last_n_tokens.push_back(result.tok); + const std::string token_str = llama_token_to_piece(ctx, result.tok); + slot.sampled = result.tok; - const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(ctx, token_with_probs.tok); - generated_text += token_text; + size_t stop_pos = + findStoppingStrings(token_str.size(), STOP_FULL, slot); - if (params.n_probs > 0) + slot.addTokenString(result); + + slot.generated_text += token_str; + + bool has_next_token = !(slot.n_decoded > 2 && + (result.tok == llama_token_eos(ctx) || + (slot.n_decoded + slot.n_past >= + params.n_predict) || + stop_pos != std::string::npos)); + + if (slot.params.n_probs > 0) { - generated_token_probs.push_back(token_with_probs); + slot.generated_token_probs.push_back(result); } - if (multibyte_pending > 0) + if (slot.multibyte_pending > 0) { - multibyte_pending -= token_text.size(); + slot.multibyte_pending -= token_str.size(); } - else if (token_text.size() == 1) + else if (token_str.size() == 1) { - const char c = token_text[0]; + const char c = token_str[0]; // 2-byte characters: 110xxxxx 10xxxxxx if ((c & 0xE0) == 0xC0) { - multibyte_pending = 1; + slot.multibyte_pending = 1; // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx } else if ((c & 0xF0) == 0xE0) { - multibyte_pending = 2; + slot.multibyte_pending = 2; // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx } else if ((c & 0xF8) == 0xF0) { - multibyte_pending = 3; + slot.multibyte_pending = 3; } else { - multibyte_pending = 0; + slot.multibyte_pending = 0; } } - if (multibyte_pending > 0 && !has_next_token) + if (slot.multibyte_pending > 0 && !has_next_token) { has_next_token = true; - n_remain++; } - if (!has_next_token && n_remain == 0) + if (!has_next_token && (slot.n_decoded + slot.n_past >= params.n_predict)) { - stopped_limit = true; + slot.stopped_limit = true; + } + + if (!slot.context_tokens.empty() && result.tok == llama_token_eos(ctx)){ + slot.stopped_eos = true; + LOG_VERBOSE("eos token found", {}); } LOG_VERBOSE("next token", { - {"token", token_with_probs.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok)}, + {"token", result.tok}, + {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, {"has_next_token", has_next_token}, - {"n_remain", n_remain}, - {"num_tokens_predicted", num_tokens_predicted}, - {"stopped_eos", stopped_eos}, - {"stopped_word", stopped_word}, - {"stopped_limit", stopped_limit}, - {"stopping_word", stopping_word}, + {"n_remain", (params.n_predict - slot.n_decoded + slot.n_past)}, + {"num_tokens_predicted", slot.num_tokens_predicted}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, }); + return has_next_token; // continue + } - return token_with_probs; + bool updateSlots() { + + // update the system prompt wait until all slots are idle state + if(update_system_prompt) { + updateSystemPrompt(); + } + + batch.n_tokens = 0; + int kv_cache_free = (n_ctx - n_tokens_system); + + if(all_slots_are_idle) { + // avoid 100% usage of cpu all time + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + // decode any currently ongoing sequences + for (auto & slot : slots) { + // release the slot + if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken()) + { + LOG_TEE("slot %i released\n", slot.id); + if(!slot.params.remember_generation) { + llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx); + slot.state = IDLE; + slot.command = NONE; + slot.num_prompt_tokens = 0; + slot.num_tokens_predicted = 0; + } else { + slot.state = SLEEPING; + slot.command = NONE; + } + continue; + } + + kv_cache_free -= 
slot.num_prompt_tokens; + + if (slot.state == IDLE || slot.command == RELEASE) { + continue; + } + + batch.token [batch.n_tokens] = slot.sampled; + batch.pos [batch.n_tokens] = n_tokens_system + slot.n_past + slot.n_decoded; + batch.seq_id[batch.n_tokens] = slot.id; + batch.logits[batch.n_tokens] = true; + + slot.n_decoded += 1; + slot.i_batch = batch.n_tokens; + + batch.n_tokens += 1; + } + + // assign workload to the slots + if (params.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // the slot needs to process its prompt + bool keep_gen = slot.state == SLEEPING; // remember generation + if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) { + slot.state = PROCESSING; + slot.command = NONE; + + auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS only if there is no system prompt + slot.num_prompt_tokens = prompt_tokens.size(); + + slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0; + + slot.context_tokens = prompt_tokens; + + LOG_VERBOSE("prompt ingested", { + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)}, + {"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())}, + }); + + std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0); + + for (size_t i = slot.n_past; i < slot.context_tokens.size(); ++i) { + batch.token [batch.n_tokens] = slot.context_tokens[i]; + batch.pos [batch.n_tokens] = i + n_tokens_system; + batch.seq_id[batch.n_tokens] = slot.id; + batch.logits[batch.n_tokens] = false; + batch.n_tokens += 1; + } + + // extract the logits only for the last token + if (batch.n_tokens > 0) { + batch.logits[batch.n_tokens - 1] = true; + } + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + } + } + } + + if (batch.n_tokens == 0) { + all_slots_are_idle = true; + return true; + } + + // process in chunks of params.n_batch + int32_t n_batch = params.n_batch; + + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + return false; + } + + LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + continue; + } + + for (auto & slot : slots) { + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; + } + + slot_params_to_gpt_params(slot.params, params); + completion_token_output result; + const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.last_n_tokens, candidates, slot.i_batch - i); + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + result.tok = id; + const int32_t n_probs = params.n_probs; + if (params.temp <= 0 && n_probs > 0) + { + // For llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &candidates_p); + } + + for (size_t i = 0; i < std::min(candidates_p.size, 
(size_t)n_probs); ++i) + { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } + + if (!processToken(result, slot)) { + slot.generated_text.clear(); + slot.release(); + } + kv_cache_free -= slot.num_tokens_predicted; + slot.i_batch = -1; + } + } + + if(kv_cache_free < 0) { + LOG_TEE("\nError: KV cache is full, increase the context size.\n"); + return false; + } + return true; } std::vector<float> getEmbedding() @@ -999,6 +1289,29 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, else if (arg == "--embedding") { params.embedding = true; + } else if (arg == "-cb" || arg == "--cont-batching") + { + params.cont_batching = true; + } + else if (arg == "-np" || arg == "--parallel") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + } else if (arg == "-n" || arg == "--n-predict") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.n_predict = std::stoi(argv[i]); + if(params.n_predict <= 128) { // this example doesn't support n_predict below 128 (slot buffers are sized with it at startup) + params.n_predict = 128; + } } else { @@ -1016,37 +1329,37 @@ } } -static json format_generation_settings(llama_server_context &llama) +static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot) { - const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx)); - const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && + const auto eos_bias = slot->params.logit_bias.find(llama_token_eos(llama.ctx)); + const bool ignore_eos = eos_bias != slot->params.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); return json{ {"n_ctx", llama.n_ctx}, {"model", llama.params.model_alias}, - {"seed", llama.params.seed}, - {"temp", llama.params.temp}, - {"top_k", llama.params.top_k}, - {"top_p", llama.params.top_p}, - {"tfs_z", llama.params.tfs_z}, - {"typical_p", llama.params.typical_p}, - {"repeat_last_n", llama.params.repeat_last_n}, - {"repeat_penalty", llama.params.repeat_penalty}, - {"presence_penalty", llama.params.presence_penalty}, - {"frequency_penalty", llama.params.frequency_penalty}, - {"mirostat", llama.params.mirostat}, - {"mirostat_tau", llama.params.mirostat_tau}, - {"mirostat_eta", llama.params.mirostat_eta}, - {"penalize_nl", llama.params.penalize_nl}, - {"stop", llama.params.antiprompt}, - {"n_predict", llama.params.n_predict}, - {"n_keep", llama.params.n_keep}, + {"seed", slot->params.seed}, + {"temp", slot->params.temp}, + {"top_k", slot->params.top_k}, + {"top_p", slot->params.top_p}, + {"tfs_z", slot->params.tfs_z}, + {"typical_p", slot->params.typical_p}, + {"repeat_last_n", slot->params.repeat_last_n}, + {"repeat_penalty", slot->params.repeat_penalty}, + {"presence_penalty", slot->params.presence_penalty}, + {"frequency_penalty", slot->params.frequency_penalty}, + {"mirostat", slot->params.mirostat}, + {"mirostat_tau", slot->params.mirostat_tau}, + {"mirostat_eta", slot->params.mirostat_eta}, + {"penalize_nl", slot->params.penalize_nl}, + {"stop", slot->params.antiprompt}, + {"n_predict", slot->params.n_predict}, + // {"n_keep", slot->params.n_keep}, {"ignore_eos", ignore_eos}, - {"stream", llama.stream}, - {"logit_bias", llama.params.logit_bias}, - {"n_probs", llama.params.n_probs}, - {"grammar", llama.params.grammar}, + {"stream", slot->params.stream}, + {"logit_bias", slot->params.logit_bias}, + {"n_probs", slot->params.n_probs}, + {"grammar", slot->params.grammar}, 
}; } @@ -1074,27 +1387,27 @@ static json format_timings(llama_server_context &llama) }; } -static json format_final_response(llama_server_context &llama, const std::string &content, const std::vector<completion_token_output> &probs) +static json format_final_response(llama_server_context &llama, llama_client_slot* &slot, const std::string &content, const std::vector<completion_token_output> &probs) { json res = json{ {"content", content}, {"stop", true}, {"model", llama.params.model_alias}, - {"tokens_predicted", llama.num_tokens_predicted}, - {"tokens_evaluated", llama.num_prompt_tokens}, - {"generation_settings", format_generation_settings(llama)}, - {"prompt", llama.prompt}, - {"truncated", llama.truncated}, - {"stopped_eos", llama.stopped_eos}, - {"stopped_word", llama.stopped_word}, - {"stopped_limit", llama.stopped_limit}, - {"stopping_word", llama.stopping_word}, - {"tokens_cached", llama.n_past}, - {"timings", format_timings(llama)}, + {"tokens_predicted", slot->num_tokens_predicted}, + {"tokens_evaluated", slot->num_prompt_tokens}, + {"generation_settings", format_generation_settings(llama, slot)}, + {"prompt", slot->prompt}, + {"truncated", slot->truncated}, + {"stopped_eos", slot->stopped_eos}, + {"stopped_word", slot->stopped_word}, + {"stopped_limit", slot->stopped_limit}, + {"stopping_word", slot->stopping_word}, + {"tokens_cached", slot->n_past}, + // {"timings", format_timings(llama)}, }; - if (llama.params.n_probs > 0) + if (slot->params.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1103,14 +1416,15 @@ } static json format_partial_response( - llama_server_context &llama, const std::string &content, const std::vector<completion_token_output> &probs + llama_server_context &llama, llama_client_slot* &slot, const std::string &content, const std::vector<completion_token_output> &probs ) { json res = json{ {"content", content}, {"stop", false}, + {"slot_id", slot->id} }; - if (llama.params.n_probs > 0) + if (slot->params.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1139,43 +1453,43 @@ static T json_value(const json &body, const std::string &key, const T &default_v : default_value; } -static void parse_options_completion(const json &body, llama_server_context &llama) +static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama) { - gpt_params default_params; + slot_params default_params; - llama.stream = json_value(body, "stream", false); - llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict); - llama.params.top_k = json_value(body, "top_k", default_params.top_k); - llama.params.top_p = json_value(body, "top_p", default_params.top_p); - llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z); - llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p); - llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n); - llama.params.temp = json_value(body, "temperature", default_params.temp); - llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty); - llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty); - llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty); - llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat); - llama.params.mirostat_tau = json_value(body, "mirostat_tau", 
default_params.mirostat_tau); - llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta); - llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); - llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); - llama.params.seed = json_value(body, "seed", default_params.seed); - llama.params.grammar = json_value(body, "grammar", default_params.grammar); - llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs); + slot->params.stream = json_value(body, "stream", false); + slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict); + slot->params.top_k = json_value(body, "top_k", default_params.top_k); + slot->params.top_p = json_value(body, "top_p", default_params.top_p); + slot->params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z); + slot->params.typical_p = json_value(body, "typical_p", default_params.typical_p); + slot->params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n); + slot->params.temp = json_value(body, "temperature", default_params.temp); + slot->params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty); + slot->params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty); + slot->params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty); + slot->params.mirostat = json_value(body, "mirostat", default_params.mirostat); + slot->params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau); + slot->params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta); + slot->params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); + //llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); + slot->params.seed = json_value(body, "seed", default_params.seed); + slot->params.grammar = json_value(body, "grammar", default_params.grammar); + slot->params.n_probs = json_value(body, "n_probs", default_params.n_probs); if (body.count("prompt") != 0) { - llama.prompt = body["prompt"]; + slot->prompt = body["prompt"]; } else { - llama.prompt = ""; + slot->prompt = ""; } - llama.params.logit_bias.clear(); + slot->params.logit_bias.clear(); if (json_value(body, "ignore_eos", false)) { - llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; + slot->params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; } const auto &logit_bias = body.find("logit_bias"); @@ -1191,18 +1505,18 @@ { if (el[1].is_number()) { - llama.params.logit_bias[tok] = el[1].get<float>(); + slot->params.logit_bias[tok] = el[1].get<float>(); } else if (el[1].is_boolean() && !el[1].get<bool>()) { - llama.params.logit_bias[tok] = -INFINITY; + slot->params.logit_bias[tok] = -INFINITY; } } } } } - llama.params.antiprompt.clear(); + slot->params.antiprompt.clear(); const auto &stop = body.find("stop"); if (stop != body.end() && stop->is_array()) { @@ -1210,34 +1524,34 @@ { if (!word.empty()) { - llama.params.antiprompt.push_back(word); + slot->params.antiprompt.push_back(word); } } } - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); + LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); } -static void parse_options_infill(const json &body, llama_server_context &llama) -{ - 
if (body.count("input_prefix") != 0) - { - llama.params.input_prefix = body["input_prefix"]; - } - else - { - llama.params.input_prefix = ""; - } - if (body.count("input_suffix") != 0) - { - llama.params.input_suffix = body["input_suffix"]; - } - else - { - llama.params.input_suffix = ""; - } - parse_options_completion(body, llama); -} +// static void parse_options_infill(const json &body, llama_server_context &llama) +// { +// if (body.count("input_prefix") != 0) +// { +// llama.params.input_prefix = body["input_prefix"]; +// } +// else +// { +// llama.params.input_prefix = ""; +// } +// if (body.count("input_suffix") != 0) +// { +// llama.params.input_suffix = body["input_suffix"]; +// } +// else +// { +// llama.params.input_suffix = ""; +// } +// parse_options_completion(body, slot, llama); +// } static void log_server_request(const Request &req, const Response &res) { @@ -1266,32 +1580,35 @@ static bool is_at_eob(llama_server_context &server_context, const llama_token *t // * When all beams converge to a common prefix, they are made available in beams_state.beams[0]. // This is also called when the stop condition is met. // Collect tokens into std::vector response which is pointed to by callback_data. -static void beam_search_callback(void *callback_data, llama_beams_state beams_state) { - auto & llama = *static_cast(callback_data); - // Mark beams as EOS as needed. - for (size_t i = 0 ; i < beams_state.n_beams ; ++i) { - llama_beam_view& beam_view = beams_state.beam_views[i]; - if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) { - beam_view.eob = true; - } - } - printf(","); // Show progress - if (const size_t n = beams_state.common_prefix_length) { - llama.generated_token_probs.resize(llama.generated_token_probs.size() + n); - assert(0u < beams_state.n_beams); - const llama_token * tokens = beams_state.beam_views[0].tokens; - const auto map = [](llama_token tok) { return completion_token_output{{},tok}; }; - std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map); - printf("%zu", n); - } - fflush(stdout); -#if 0 // DEBUG: print current beams for this iteration - std::cout << "\n\nCurrent beams:\n"; - for (size_t i=0 ; i < beams_state.n_beams ; ++i) { - std::cout << "beams["<(callback_data); +// // Mark beams as EOS as needed. 
+// for (size_t i = 0 ; i < beams_state.n_beams ; ++i) { +// llama_beam_view& beam_view = beams_state.beam_views[i]; +// if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) { +// beam_view.eob = true; +// } +// } +// printf(","); // Show progress +// if (const size_t n = beams_state.common_prefix_length) { +// llama.generated_token_probs.resize(llama.generated_token_probs.size() + n); +// assert(0u < beams_state.n_beams); +// const llama_token * tokens = beams_state.beam_views[0].tokens; +// const auto map = [](llama_token tok) { return completion_token_output{{},tok}; }; +// std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map); +// printf("%zu", n); +// } +// fflush(stdout); +// #if 0 // DEBUG: print current beams for this iteration +// std::cout << "\n\nCurrent beams:\n"; +// for (size_t i=0 ; i < beams_state.n_beams ; ++i) { +// std::cout << "beams["<(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript"); return false; }); - svr.Post("/completion", [&llama](const Request &req, Response &res) + svr.Get("/props", [&llama](const Request & /*req*/, Response &res) + { + res.set_header("Access-Control-Allow-Origin", "*"); + json data = { + { "user_name", llama.user_name.c_str() }, + { "assistant_name", llama.assistant_name.c_str() } + }; + res.set_content(data.dump(), "application/json"); }); + + svr.Post("/completion", [&](const Request &req, Response &res) { - auto lock = llama.lock(); + //auto lock = llama.lock(); - llama.rewind(); + json data = json::parse(req.body); - llama_reset_timings(llama.ctx); + llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1)); - parse_options_completion(json::parse(req.body), llama); + if(slot == nullptr) { + LOG_TEE("slot unavailable\n"); + res.status = 404; + res.set_content("slot_error", "text/plain"); + return; + } - if (!llama.loadGrammar()) + if(data.contains("system_prompt")) { + llama.processSystemPromptData(data["system_prompt"]); + } + + // llama_reset_timings(llama.ctx); + + slot->reset(); + + parse_options_completion(json::parse(req.body), slot, llama); + + if (!llama.launchSlot(slot)) { res.status = 400; return; } - llama.loadPrompt(); - llama.beginCompletion(); + if (!slot->params.stream) { + // if (llama.params.n_beams) { + // // Fill llama.generated_token_probs vector with final beam. + // llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, + // llama.n_past, llama.n_remain); + // // Translate llama.generated_token_probs to llama.generated_text. + // append_to_generated_text_from_generated_token_probs(llama); + // } else { + // size_t stop_pos = std::string::npos; - if (!llama.stream) { - if (llama.params.n_beams) { - // Fill llama.generated_token_probs vector with final beam. - llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, - llama.n_past, llama.n_remain); - // Translate llama.generated_token_probs to llama.generated_text. - append_to_generated_text_from_generated_token_probs(llama); - } else { - size_t stop_pos = std::string::npos; + // while (llama.has_next_token) { + // const completion_token_output token_with_probs = llama.doCompletion(); + // const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok); - while (llama.has_next_token) { - const completion_token_output token_with_probs = llama.doCompletion(); - const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(llama.ctx, token_with_probs.tok); + // stop_pos = llama.findStoppingStrings(llama.generated_text, + // token_text.size(), STOP_FULL); + // } - stop_pos = llama.findStoppingStrings(llama.generated_text, - token_text.size(), STOP_FULL); - } + // if (stop_pos == std::string::npos) { + // stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL); + // } + // if (stop_pos != std::string::npos) { + // llama.generated_text.erase(llama.generated_text.begin() + stop_pos, + // llama.generated_text.end()); + // } + // } - if (stop_pos == std::string::npos) { - stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL); - } - if (stop_pos != std::string::npos) { - llama.generated_text.erase(llama.generated_text.begin() + stop_pos, - llama.generated_text.end()); - } - } + // auto probs = llama.generated_token_probs; + // if (llama.params.n_probs > 0 && llama.stopped_word) { + // const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); + // probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); + // } - auto probs = llama.generated_token_probs; - if (llama.params.n_probs > 0 && llama.stopped_word) { - const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); - probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); - } + // const json data = format_final_response(llama, llama.generated_text, probs); - const json data = format_final_response(llama, llama.generated_text, probs); + // llama_print_timings(llama.ctx); - llama_print_timings(llama.ctx); - - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), - "application/json"); + // res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), + // "application/json"); } else { - const auto chunked_content_provider = [&](size_t, DataSink & sink) { - size_t sent_count = 0; - size_t sent_token_probs_index = 0; + auto chunked_content_provider = [&](size_t /*offset*/, DataSink &sink) { + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + while(slot->isProcessing()) { + if(slot->hasNewToken()) { // new token notification + const completion_token_output token = slot->next(); + std::string token_str = llama_token_to_piece(llama.ctx, token.tok); - while (llama.has_next_token) { - const completion_token_output token_with_probs = llama.doCompletion(); - if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) { - continue; - } - const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); + std::vector probs_output = {}; - size_t pos = std::min(sent_count, llama.generated_text.size()); + const json data = format_partial_response(llama, slot, token_str, probs_output); + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; - const std::string str_test = llama.generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = - llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); - if (stop_pos != std::string::npos) { - is_stop_full = true; - llama.generated_text.erase( - llama.generated_text.begin() + pos + stop_pos, - llama.generated_text.end()); - pos = std::min(sent_count, llama.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = llama.findStoppingStrings(str_test, token_text.size(), - STOP_PARTIAL); - } - - if ( - stop_pos == std::string::npos || 
- // Send rest of the text if we are at the end of the generation - (!llama.has_next_token && !is_stop_full && stop_pos > 0) - ) { - const std::string to_send = llama.generated_text.substr(pos, std::string::npos); - - sent_count += to_send.size(); - - std::vector probs_output = {}; - - if (llama.params.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { - probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + if(!sink.write(str.c_str(), str.size())) { + slot->release(); + return false; } - sent_token_probs_index = probs_stop_pos; - } - - const json data = format_partial_response(llama, to_send, probs_output); - - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.data(), str.size())) { - LOG_VERBOSE("stream closed", {}); - llama_print_timings(llama.ctx); - return false; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); } } - - if (!llama.has_next_token) { - // Generation is done, send extra information. - const json data = format_final_response( - llama, + const json data = format_final_response( + llama, slot, "", - std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) + std::vector( + slot->generated_token_probs.begin(), + slot->generated_token_probs.begin() + sent_token_probs_index) ); const std::string str = @@ -1525,145 +1824,144 @@ int main(int argc, char **argv) llama_print_timings(llama.ctx); return false; } - } - } - - llama_print_timings(llama.ctx); - sink.done(); - return true; + sink.done(); + return true; + }; + auto on_complete = [&] (bool) { + //llama.mutex.unlock(); + slot->sent_tokens = 0; + slot->generated_token_probs.clear(); + slot->release(); }; - const auto on_complete = [&](bool) { - llama.mutex.unlock(); - }; - lock.release(); + //lock.release(); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); - svr.Post("/infill", [&llama](const Request &req, Response &res) - { - auto lock = llama.lock(); + // svr.Post("/infill", [&llama](const Request &req, Response &res) + // { + // auto lock = llama.lock(); - llama.rewind(); + // llama.rewind(); - llama_reset_timings(llama.ctx); + // llama_reset_timings(llama.ctx); - parse_options_infill(json::parse(req.body), llama); + // parse_options_infill(json::parse(req.body), llama); - if (!llama.loadGrammar()) - { - res.status = 400; - return; - } - llama.loadInfill(); - llama.beginCompletion(); - const auto chunked_content_provider = [&](size_t, DataSink & sink) { - size_t sent_count = 0; - size_t sent_token_probs_index = 0; + // if (!llama.loadGrammar()) + // { + // res.status = 400; + // return; + // } + // llama.loadInfill(); + // llama.beginCompletion(); + // const auto chunked_content_provider = [&](size_t, DataSink & sink) { + // size_t sent_count = 0; + // size_t sent_token_probs_index = 0; - while (llama.has_next_token) { - const completion_token_output token_with_probs = llama.doCompletion(); - if (token_with_probs.tok == -1 || 
llama.multibyte_pending > 0) { - continue; - } - const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); + // while (llama.has_next_token) { + // const completion_token_output token_with_probs = llama.doCompletion(); + // if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) { + // continue; + // } + // const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); - size_t pos = std::min(sent_count, llama.generated_text.size()); + // size_t pos = std::min(sent_count, llama.generated_text.size()); - const std::string str_test = llama.generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = - llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); - if (stop_pos != std::string::npos) { - is_stop_full = true; - llama.generated_text.erase( - llama.generated_text.begin() + pos + stop_pos, - llama.generated_text.end()); - pos = std::min(sent_count, llama.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = llama.findStoppingStrings(str_test, token_text.size(), - STOP_PARTIAL); - } + // const std::string str_test = llama.generated_text.substr(pos); + // bool is_stop_full = false; + // size_t stop_pos = + // llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); + // if (stop_pos != std::string::npos) { + // is_stop_full = true; + // llama.generated_text.erase( + // llama.generated_text.begin() + pos + stop_pos, + // llama.generated_text.end()); + // pos = std::min(sent_count, llama.generated_text.size()); + // } else { + // is_stop_full = false; + // stop_pos = llama.findStoppingStrings(str_test, token_text.size(), + // STOP_PARTIAL); + // } - if ( - stop_pos == std::string::npos || - // Send rest of the text if we are at the end of the generation - (!llama.has_next_token && !is_stop_full && stop_pos > 0) - ) { - const std::string to_send = llama.generated_text.substr(pos, std::string::npos); + // if ( + // stop_pos == std::string::npos || + // // Send rest of the text if we are at the end of the generation + // (!llama.has_next_token && !is_stop_full && stop_pos > 0) + // ) { + // const std::string to_send = llama.generated_text.substr(pos, std::string::npos); - sent_count += to_send.size(); + // sent_count += to_send.size(); - std::vector probs_output = {}; + // std::vector probs_output = {}; - if (llama.params.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { - probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); - } - sent_token_probs_index = probs_stop_pos; - } + // if (llama.params.n_probs > 0) { + // const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); + // size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + // size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + // if (probs_pos < probs_stop_pos) { + // probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + // } + // sent_token_probs_index = probs_stop_pos; + // } - const json data = format_partial_response(llama, to_send, probs_output); + // const json 
data = format_partial_response(llama, to_send, probs_output); - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; + // const std::string str = + // "data: " + + // data.dump(-1, ' ', false, json::error_handler_t::replace) + + // "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); + // LOG_VERBOSE("data stream", { + // { "to_send", str } + // }); - if (!sink.write(str.data(), str.size())) { - LOG_VERBOSE("stream closed", {}); - llama_print_timings(llama.ctx); - return false; - } - } + // if (!sink.write(str.data(), str.size())) { + // LOG_VERBOSE("stream closed", {}); + // llama_print_timings(llama.ctx); + // return false; + // } + // } - if (!llama.has_next_token) { - // Generation is done, send extra information. - const json data = format_final_response( - llama, - "", - std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) - ); + // if (!llama.has_next_token) { + // // Generation is done, send extra information. + // const json data = format_final_response( + // llama, + // "", + // std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) + // ); - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; + // const std::string str = + // "data: " + + // data.dump(-1, ' ', false, json::error_handler_t::replace) + + // "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); + // LOG_VERBOSE("data stream", { + // { "to_send", str } + // }); - if (!sink.write(str.data(), str.size())) { - LOG_VERBOSE("stream closed", {}); - llama_print_timings(llama.ctx); - return false; - } - } - } + // if (!sink.write(str.data(), str.size())) { + // LOG_VERBOSE("stream closed", {}); + // llama_print_timings(llama.ctx); + // return false; + // } + // } + // } - llama_print_timings(llama.ctx); - sink.done(); - return true; - }; - const auto on_complete = [&](bool) { - llama.mutex.unlock(); - }; - lock.release(); - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - }); + // llama_print_timings(llama.ctx); + // sink.done(); + // return true; + // }; + // const auto on_complete = [&](bool) { + // llama.mutex.unlock(); + // }; + // lock.release(); + // res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + // }); - svr.Get("/model.json", [&llama](const Request &, Response &res) - { - const json data = format_generation_settings(llama); - return res.set_content(data.dump(), "application/json"); }); + // svr.Get("/model.json", [&llama](const Request &, Response &res) + // { + // const json data = format_generation_settings(llama); + // return res.set_content(data.dump(), "application/json"); }); svr.Options(R"(/.*)", [](const Request &, Response &res) { return res.set_content("", "application/json"); }); @@ -1696,29 +1994,29 @@ int main(int argc, char **argv) const json data = format_detokenized_response(content); return res.set_content(data.dump(), "application/json"); }); - svr.Post("/embedding", [&llama](const Request &req, Response &res) - { - auto lock = llama.lock(); + // svr.Post("/embedding", [&llama](const Request &req, Response &res) + // { + // auto lock = llama.lock(); - const json body = json::parse(req.body); + // const json body = json::parse(req.body); - llama.rewind(); - llama_reset_timings(llama.ctx); - if (body.count("content") != 0) - { - llama.prompt = 
body["content"]; + // } + // else + // { + // llama.prompt = ""; + // } + // llama.params.n_predict = 0; + // llama.loadPrompt(); + // llama.beginCompletion(); + // llama.doCompletion(); - const json data = format_embedding_response(llama); - return res.set_content(data.dump(), "application/json"); }); + // const json data = format_embedding_response(llama); + // return res.set_content(data.dump(), "application/json"); }); svr.set_logger(log_server_request); @@ -1755,6 +2053,16 @@ int main(int argc, char **argv) return 1; } + if(!params.embedding) { + std::thread t([&llama]() + { + bool running = true; + while (running) + { + running = llama.updateSlots(); + } }); + t.detach(); // keep the slot-update loop running for the server's lifetime; destroying a joinable thread would call std::terminate + } + // Set the base directory for serving static files svr.set_base_dir(sparams.public_path); @@ -1770,10 +2078,6 @@ int main(int argc, char **argv) { return 1; } - - if (llama.grammar != nullptr) { - llama_grammar_free(llama.grammar); - } llama_backend_free(); return 0;
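Usage note (not part of the patch): below is a minimal client sketch for the slot-based API this patch introduces, assuming a server started with parallel decoding enabled (e.g. `./server -m model.gguf -np 4 -cb`) and listening on the example's default 127.0.0.1:8080. It reuses the cpp-httplib and nlohmann::json headers the server example already vendors. Since only the streaming path is wired up in this patch (the non-streaming branch is commented out), the request sets "stream": true; httplib buffers the chunked reply, so res->body ends up holding the concatenated "data: {...}" SSE events. A "slot_id" of -1 asks the server to pick any idle slot, and the optional "system_prompt" object (fields "system_prompt" / "anti_prompt" / "assistant_name", per processSystemPromptData) updates the prompt shared by all slots. The prompt text and host/port here are illustrative assumptions.

// client_sketch.cpp - illustrative only, not part of the patch
#include "httplib.h"   // cpp-httplib, vendored by the server example
#include "json.hpp"    // nlohmann::json, likewise
#include <cstdio>

using json = nlohmann::json;

int main() {
    httplib::Client cli("127.0.0.1", 8080); // default host/port of the server example

    json req = {
        {"prompt",    "Hello, my name is"},  // hypothetical prompt
        {"n_predict", 64},
        {"stream",    true},                 // only the streaming path is implemented by this patch
        {"slot_id",   -1},                   // -1 = let the server pick any idle slot
        // optional: broadcast a shared system prompt to all slots
        {"system_prompt", {
            {"system_prompt",  "You are a helpful assistant."},
            {"anti_prompt",    "User:"},
            {"assistant_name", "Assistant:"}
        }}
    };

    // httplib collects the whole chunked (SSE) response into res->body
    auto res = cli.Post("/completion", req.dump(), "application/json");
    if (res && res->status == 200) {
        std::printf("%s\n", res->body.c_str());
    }
    return 0;
}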