From 29c8cdd65d1c6ced44f9f0d4f4c2d03a215ebed5 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 12 Oct 2023 15:02:19 -0400 Subject: [PATCH] refactored sampling function --- examples/server/server.cpp | 750 ++++++++++++++----------------------- 1 file changed, 289 insertions(+), 461 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index de10af2b0..ad2ae4ee4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -75,25 +75,6 @@ struct slot_params { bool stream = true; uint32_t seed = -1; // RNG seed int32_t n_predict = 128; // new tokens to predict - - // sampler params - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - int n_probs = 0; - bool penalize_nl = false; - - std::unordered_map logit_bias; // logit bias for specific tokens - std::string grammar = ""; // optional BNF-like grammar to constrain sampling bool remember_generation = false; // remember a part of the prompt to avoid reprocessing all prompt std::vector antiprompt; @@ -133,23 +114,6 @@ static bool ends_with(const std::string &str, const std::string &suffix) 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } -static void slot_params_to_gpt_params(const slot_params &src, gpt_params & dst) -{ - dst.frequency_penalty = src.frequency_penalty; - dst.temp = src.temp; - dst.top_k = src.top_k; - dst.top_p = src.top_p; - dst.grammar = src.grammar; - dst.logit_bias = src.logit_bias; - dst.mirostat = src.mirostat; - dst.mirostat_eta = src.mirostat_eta; - dst.mirostat_tau = src.mirostat_tau; - dst.typical_p = src.typical_p; - dst.repeat_penalty = src.repeat_penalty; - dst.repeat_last_n = src.repeat_last_n; - dst.presence_penalty = src.presence_penalty; -} - static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { @@ -247,10 +211,10 @@ struct llama_client_slot { int id; // generation props - int32_t num_prompt_tokens = 0; + int32_t n_past = 0; int32_t n_decoded = 0; int32_t i_batch = -1; - size_t n_past = 0; + int32_t num_prompt_tokens = 0; json prompt; std::string generated_text = ""; int num_tokens_predicted = 0; @@ -268,7 +232,9 @@ struct llama_client_slot std::string stopping_word; int32_t multibyte_pending = 0; - slot_params params; + struct slot_params params; + struct llama_sampling_params sparams; + llama_sampling_context ctx_sampling; // grammar props grammar_parser::parse_state parsed_grammar; @@ -292,12 +258,14 @@ struct llama_client_slot if (grammar != nullptr) { llama_grammar_free(grammar); grammar = nullptr; + ctx_sampling.params = sparams; + ctx_sampling.grammar = NULL; } // llama_set_rng_seed(ctx, params.seed); in batched the seed matter??????? 
} - bool loadGrammar() + bool loadGrammar(llama_token eos) { if (!params.grammar.empty()) { parsed_grammar = grammar_parser::parse(params.grammar.c_str()); @@ -308,18 +276,19 @@ struct llama_client_slot } grammar_parser::print_grammar(stderr, parsed_grammar); - // TODO: fix this comment - // { - // auto it = params.logit_bias.find(llama_token_eos(ctx)); - // if (it != params.logit_bias.end() && it->second == -INFINITY) { - // LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); - // } - // } + { + auto it = sparams.logit_bias.find(eos); + if (it != sparams.logit_bias.end() && it->second == -INFINITY) { + LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); + } + } std::vector grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + ctx_sampling.params = sparams; + ctx_sampling.grammar = grammar; return true; } @@ -368,7 +337,7 @@ struct llama_server_context std::string system_prompt = ""; bool update_system_prompt = false; std::vector tokens_system; - int32_t n_tokens_system = 0; + int32_t num_tokens_system; // broadcast to all clients to keep the same prompt format std::string user_name = ""; // this should be the anti prompt @@ -380,7 +349,6 @@ struct llama_server_context std::vector candidates; bool all_slots_are_idle = false; gpt_params params; - llama_sampling_context ctx_sampling; int n_ctx; int n_vocab; bool clean_kv_cache = true; @@ -403,29 +371,10 @@ struct llama_server_context llama_free_model(model); model = nullptr; } - } - - void rewind() - { - params.antiprompt.clear(); - params.grammar.clear(); - num_prompt_tokens = 0; - num_tokens_predicted = 0; - generated_text = ""; - generated_text.reserve(n_ctx); - generated_token_probs.clear(); - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_remain = 0; - n_past = 0; - - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; + for(auto &slot : slots) { + if(slot.grammar) { + llama_grammar_free(slot.grammar); + } } } @@ -447,6 +396,7 @@ struct llama_server_context void initialize() { // create slots LOG_TEE("Available slots:\n"); + all_slots_are_idle = true; for (int i = 0; i < params.n_parallel; i++) { llama_client_slot slot; @@ -457,11 +407,11 @@ struct llama_server_context LOG_TEE(" - slot %i\n", slot.id); slots.push_back(slot); } + LOG_TEE("Context Size: %i\n", params.n_ctx); batch = llama_batch_init(params.n_ctx, 0); - // empty system prompt system_prompt = ""; - all_slots_are_idle = true; + num_tokens_system = 0; } std::vector tokenize(const json & json_prompt, bool add_bos) const @@ -510,28 +460,57 @@ struct llama_server_context return prompt_tokens; } - bool loadGrammar() - { - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); + void processPrompt() { + //params.n_keep = std::min(n_ctx - 4, params.n_keep); + // if input prompt is too big, truncate like normal + // if (num_prompt_tokens >= (size_t)n_ctx) + // { + // const int n_left = (n_ctx - params.n_keep) / 2; + // std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + // const int 
erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + // new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + // std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + // LOG_VERBOSE("input truncated", { + // {"n_ctx", n_ctx}, + // {"n_keep", params.n_keep}, + // {"n_left", n_left}, + // {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + // }); + + // truncated = true; + // prompt_tokens = new_tokens; + // } + // else + // { + // const size_t ps = num_prompt_tokens; + // std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + // std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + // } + + // compare the evaluated prompt with the new prompt + } + + + llama_client_slot* getSlot(int id) { + for (llama_client_slot & slot : slots) + { + if ((id == -1 && slot.available()) || slot.id == id) { - auto it = params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) { - LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); - } + return &slot; } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + return nullptr; + } + + bool launchSlot(llama_client_slot* &slot) { + if(!slot->loadGrammar(llama_token_eos(ctx))) { + return false; + } + all_slots_are_idle = false; + slot->command = LOAD_PROMPT; + LOG_TEE("slot %i is processing\n", slot->id); return true; } @@ -592,15 +571,15 @@ struct llama_server_context // std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); // } - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - printf("we have to evaluate at least 1 token to generate logits\n"); - n_past--; - } + // // compare the evaluated prompt with the new prompt + // n_past = common_part(embd, prompt_tokens); + // embd = prompt_tokens; + // if (n_past == num_prompt_tokens) + // { + // // we have to evaluate at least 1 token to generate logits. 
+ // printf("we have to evaluate at least 1 token to generate logits\n"); + // n_past--; + // } // LOG_VERBOSE("prompt ingested", { // {"n_past", n_past}, @@ -617,172 +596,81 @@ struct llama_server_context { llama_kv_cache_seq_rm(ctx, i, 0, -1); } - params.n_keep = std::min(n_ctx - 4, params.n_keep); - - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)n_ctx) - { - const int n_left = (n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; - } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } - - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - n_past--; - } - - LOG_VERBOSE("prompt ingested", { - {"n_past", n_past}, - {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); - - has_next_token = true; + clean_kv_cache = false; } - void beginCompletion() - { - // number of tokens to keep when resetting context - n_remain = params.n_predict; - llama_set_rng_seed(ctx, params.seed); + void updateSystemPrompt() { + tokens_system = ::llama_tokenize(ctx, system_prompt, true); + num_tokens_system = tokens_system.size(); + + batch.n_tokens = num_tokens_system; + + cleanKVCache(); + + for (int32_t i = 0; i < batch.n_tokens; ++i) + { + batch.token[i] = tokens_system[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; + } + + if (llama_decode(ctx, batch) != 0) + { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return; + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < params.n_parallel; ++i) + { + llama_kv_cache_seq_cp(ctx, 0, i, 0, num_tokens_system); + } + + LOG_TEE("system prompt updated\n"); + update_system_prompt = false; } - completion_token_output nextToken() - { - completion_token_output result; - result.tok = -1; - - if (embd.size() >= (size_t)n_ctx) + void notifySystemPromptChanged() { + // release all slots + for (llama_client_slot &slot : slots) { - // Shift context - - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; - - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - - for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) - { - embd[i - n_discard] = embd[i]; - } - embd.resize(embd.size() - n_discard); - - n_past -= n_discard; - - truncated = true; - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - 
{"n_keep", params.n_keep}, - {"n_left", n_left}, - }); + slot.release(); } - - bool tg = true; - while (n_past < embd.size()) - { - int n_eval = (int)embd.size() - n_past; - tg = n_eval == 1; - if (n_eval > params.n_batch) - { - n_eval = params.n_batch; - } - - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) - { - LOG_ERROR("failed to eval", { - {"n_eval", n_eval}, - {"n_past", n_past}, - {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); - has_next_token = false; - return result; - } - n_past += n_eval; + waitAllAreIdle(); + all_slots_are_idle = true; + // wait until system prompt load + update_system_prompt = true; + while(update_system_prompt) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); } - - if (params.n_predict == 0) - { - has_next_token = false; - result.tok = llama_token_eos(ctx); - return result; - } - - { - // out of user input, sample next token - std::vector candidates; - candidates.reserve(llama_n_vocab(model)); - - result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates); - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - const int32_t n_probs = params.n_probs; - if (params.temp <= 0 && n_probs > 0) - { - // For llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &candidates_p); - } - - for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) - { - result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); - } - - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(result.tok); - if (tg) { - num_tokens_predicted++; - } - } - - // add it to the context - embd.push_back(result.tok); - // decrement remaining sampling budget - --n_remain; - - if (!embd.empty() && embd.back() == llama_token_eos(ctx)) - { - // stopping_word = llama_token_to_piece(ctx, embd.back()); - has_next_token = false; - stopped_eos = true; - LOG_VERBOSE("eos token found", {}); - return result; - } - - has_next_token = params.n_predict == -1 || n_remain != 0; - return result; + // system prompt loaded, continue } - size_t findStoppingStrings(const size_t last_token_size, - const stop_type type, llama_client_slot & slot) + void processSystemPromptData(json sys_props) { + system_prompt = sys_props.value("system_prompt", ""); + user_name = sys_props.value("anti_prompt", ""); + assistant_name = sys_props.value("assistant_name", ""); + notifySystemPromptChanged(); + } + + void waitAllAreIdle() { + bool wait = true; + while(wait) { + wait = false; + for (auto &slot : slots) + { + if (!slot.available()) + { + wait = true; + break; + } + } + } + } + + size_t findStoppingStrings(const std::string &text, const size_t last_token_size, + const stop_type type, llama_client_slot &slot) { size_t stop_pos = std::string::npos; for (const std::string &word : slot.params.antiprompt) @@ -791,22 +679,19 @@ struct llama_server_context if (type == STOP_FULL) { const size_t tmp = word.size() + last_token_size; - const size_t from_pos = slot.generated_text.size() > tmp ? slot.generated_text.size() - tmp : 0; - pos = slot.generated_text.find(word, from_pos); + const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; + pos = text.find(word, from_pos); } else { - pos = find_partial_stop_string(word, slot.generated_text); + pos = find_partial_stop_string(word, text); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_FULL) - { - slot.stopping_word = word; - slot.stopped_word = true; - } stop_pos = pos; + slot.stopped_word = true; + slot.stopping_word = word; } } return stop_pos; @@ -817,23 +702,19 @@ struct llama_server_context slot.last_n_tokens.erase(slot.last_n_tokens.begin()); slot.last_n_tokens.push_back(result.tok); const std::string token_str = llama_token_to_piece(ctx, result.tok); - printf("%s", token_str.c_str()); slot.sampled = result.tok; - - size_t stop_pos = - findStoppingStrings(token_str.size(), STOP_FULL, slot); - slot.addTokenString(result); - slot.generated_text += token_str; + size_t stop_pos = findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL, slot); + bool has_next_token = !(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) || (slot.n_decoded + slot.n_past >= params.n_predict) || stop_pos != std::string::npos)); - if (params.n_probs > 0) + if (slot.sparams.n_probs > 0) { slot.generated_token_probs.push_back(result); } @@ -902,7 +783,7 @@ struct llama_server_context } batch.n_tokens = 0; - int kv_cache_free = (n_ctx - n_tokens_system); + int kv_cache_free = (n_ctx - num_tokens_system); if(all_slots_are_idle) { if(system_prompt.empty() && clean_kv_cache) { @@ -919,18 +800,16 @@ struct llama_server_context { LOG_TEE("slot %i released\n", slot.id); if(!slot.params.remember_generation) { - llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx); + llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system, n_ctx); slot.state = IDLE; slot.command = NONE; - slot.num_prompt_tokens = 0; - slot.num_tokens_predicted = 0; + slot.generated_text.clear(); } else { slot.state = SLEEPING; slot.command = NONE; } continue; } - kv_cache_free -= slot.num_prompt_tokens; if (slot.state == IDLE || slot.command == RELEASE) { @@ -938,7 +817,7 @@ struct llama_server_context } batch.token [batch.n_tokens] = slot.sampled; - batch.pos [batch.n_tokens] = n_tokens_system + slot.n_past + slot.n_decoded; + batch.pos [batch.n_tokens] = num_tokens_system + slot.n_past + slot.n_decoded; batch.seq_id[batch.n_tokens] = slot.id; batch.logits[batch.n_tokens] = true; @@ -954,13 +833,11 @@ struct llama_server_context // need process the prompt bool keep_gen = slot.state == SLEEPING; // remember generation if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) { - LOG_TEE("processing prompt\n"); slot.state = PROCESSING; slot.command = NONE; auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt slot.num_prompt_tokens = prompt_tokens.size(); - slot.n_past = keep_gen ? 
common_part(slot.context_tokens, prompt_tokens) : 0; slot.context_tokens = prompt_tokens; @@ -972,14 +849,14 @@ struct llama_server_context }); if(system_prompt.empty()) { - LOG_TEE("cleaning kv: %i\n", slot.n_past); llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); } std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0); - for (size_t i = slot.n_past; i < slot.context_tokens.size(); ++i) { - batch.token [batch.n_tokens] = slot.context_tokens[i]; - batch.pos [batch.n_tokens] = i + n_tokens_system; + for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) { + printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str()); + batch.token [batch.n_tokens] = prompt_tokens[slot.n_past]; + batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system; batch.seq_id[batch.n_tokens] = slot.id; batch.logits[batch.n_tokens] = false; batch.n_tokens += 1; @@ -1037,13 +914,12 @@ struct llama_server_context continue; } - slot_params_to_gpt_params(slot.params, params); completion_token_output result; - const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.last_n_tokens, candidates, slot.i_batch - i); + const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i); llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; result.tok = id; - const int32_t n_probs = params.n_probs; - if (params.temp <= 0 && n_probs > 0) + const int32_t n_probs = slot.sparams.n_probs; + if (slot.sparams.temp <= 0 && n_probs > 0) { // For llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &candidates_p); @@ -1055,7 +931,6 @@ struct llama_server_context } if (!processToken(result, slot)) { - slot.generated_text.clear(); slot.release(); } kv_cache_free -= slot.num_tokens_predicted; @@ -1092,16 +967,15 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf("usage: %s [options]\n", argv0); printf("\n"); printf("options:\n"); - printf(" -h, --help show this help message and exit\n"); - printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); - printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n"); - printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); - printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n"); - printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); - printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); + printf(" -h, --help show this help message and exit\n"); + printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? 
"enabled" : "disabled"); + printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n"); + printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); + printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); + printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); @@ -1246,15 +1120,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_threads = std::stoi(argv[i]); } - else if (arg == "--threads-batch" || arg == "-tb") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.n_threads_batch = std::stoi(argv[i]); - } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) @@ -1432,35 +1297,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot) { - const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx)); - const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && + const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx)); + const bool ignore_eos = eos_bias != slot->sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); return json{ {"n_ctx", llama.n_ctx}, {"model", llama.params.model_alias}, - {"seed", llama.params.seed}, - {"temp", llama.params.temp}, - {"top_k", llama.params.top_k}, - {"top_p", llama.params.top_p}, - {"tfs_z", llama.params.tfs_z}, - {"typical_p", llama.params.typical_p}, - {"repeat_last_n", llama.params.repeat_last_n}, - {"repeat_penalty", llama.params.repeat_penalty}, - {"presence_penalty", llama.params.presence_penalty}, - {"frequency_penalty", llama.params.frequency_penalty}, - {"mirostat", llama.params.mirostat}, - {"mirostat_tau", llama.params.mirostat_tau}, - {"mirostat_eta", llama.params.mirostat_eta}, - {"penalize_nl", llama.params.penalize_nl}, - {"stop", llama.params.antiprompt}, - {"n_predict", llama.params.n_predict}, - {"n_keep", llama.params.n_keep}, + {"seed", slot->params.seed}, + {"temp", slot->sparams.temp}, + {"top_k", slot->sparams.top_k}, + {"top_p", slot->sparams.top_p}, + {"tfs_z", slot->sparams.tfs_z}, + {"typical_p", slot->sparams.typical_p}, + {"repeat_last_n", slot->sparams.repeat_last_n}, + {"repeat_penalty", slot->sparams.repeat_penalty}, + {"presence_penalty",slot->sparams.presence_penalty}, + {"frequency_penalty", slot->sparams.frequency_penalty}, + {"mirostat", slot->sparams.mirostat}, + {"mirostat_tau", slot->sparams.mirostat_tau}, + {"mirostat_eta", slot->sparams.mirostat_eta}, + {"penalize_nl", slot->sparams.penalize_nl}, + {"stop", slot->params.antiprompt}, + {"n_predict", slot->params.n_predict}, + // {"n_keep", slot.params.n_keep}, {"ignore_eos", ignore_eos}, - {"stream", llama.stream}, - {"logit_bias", llama.params.logit_bias}, - {"n_probs", llama.params.n_probs}, - {"grammar", llama.params.grammar}, + {"stream", slot->params.stream}, + {"logit_bias", slot->sparams.logit_bias}, + {"n_probs", slot->sparams.n_probs}, 
+ {"grammar", slot->params.grammar}, }; } @@ -1488,7 +1353,7 @@ static json format_timings(llama_server_context &llama) }; } -static json format_final_response(llama_server_context &llama, llama_client_slot* &slot, const std::string &content, const std::vector &probs) +static json format_final_response(llama_server_context &llama, llama_client_slot* slot, const std::string &content, const std::vector &probs) { json res = json{ @@ -1508,7 +1373,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot // {"timings", format_timings(llama)}, }; - if (llama.params.n_probs > 0) + if (slot->sparams.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1517,7 +1382,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot } static json format_partial_response( - llama_server_context &llama, llama_client_slot* &slot, const std::string &content, const std::vector &probs + llama_server_context &llama, llama_client_slot* slot, const std::string &content, const std::vector &probs ) { json res = json{ {"content", content}, @@ -1525,7 +1390,7 @@ static json format_partial_response( { "slot_id", slot->id } }; - if (llama.params.n_probs > 0) + if (slot->sparams.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1556,27 +1421,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama) { - gpt_params default_params; + slot_params default_params; + llama_sampling_params default_sparams; - llama.stream = json_value(body, "stream", false); - llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict); - llama.params.top_k = json_value(body, "top_k", default_params.top_k); - llama.params.top_p = json_value(body, "top_p", default_params.top_p); - llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z); - llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p); - llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n); - llama.params.temp = json_value(body, "temperature", default_params.temp); - llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty); - llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty); - llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty); - llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat); - llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau); - llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta); - llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); - llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); - llama.params.seed = json_value(body, "seed", default_params.seed); - llama.params.grammar = json_value(body, "grammar", default_params.grammar); - llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs); + slot->params.stream = json_value(body, "stream", false); + slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(body, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(body, "top_p", default_sparams.top_p); + 
slot->sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p); + slot->sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n); + slot->sparams.temp = json_value(body, "temperature", default_sparams.temp); + slot->sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty); + slot->sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty); + slot->sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty); + slot->sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl); + //llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); + slot->params.seed = json_value(body, "seed", default_params.seed); + slot->params.grammar = json_value(body, "grammar", default_params.grammar); + slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs); if (body.count("prompt") != 0) { @@ -1587,10 +1453,10 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot, slot->prompt = ""; } - llama.params.logit_bias.clear(); + slot->sparams.logit_bias.clear(); if (json_value(body, "ignore_eos", false)) { - llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; + slot->sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; } const auto &logit_bias = body.find("logit_bias"); @@ -1606,11 +1472,11 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot, { if (el[1].is_number()) { - llama.params.logit_bias[tok] = el[1].get(); + slot->sparams.logit_bias[tok] = el[1].get(); } else if (el[1].is_boolean() && !el[1].get()) { - llama.params.logit_bias[tok] = -INFINITY; + slot->sparams.logit_bias[tok] = -INFINITY; } } } @@ -1630,8 +1496,6 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot, } } - llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar); - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); } @@ -1809,7 +1673,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const Request &req, Response &res) { - //auto lock = llama.lock(); + auto lock = llama.lock(); json data = json::parse(req.body); @@ -1865,11 +1729,11 @@ int main(int argc, char **argv) // } // } - auto probs = llama.generated_token_probs; - if (llama.params.n_probs > 0 && llama.stopped_word) { - const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); - probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); - } + // auto probs = llama.generated_token_probs; + // if (llama.params.n_probs > 0 && llama.stopped_word) { + // const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); + // probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); + // } // const json data = format_final_response(llama, llama.generated_text, probs); @@ -1878,104 +1742,67 @@ int main(int argc, char **argv) // res.set_content(data.dump(-1, ' ', false, 
json::error_handler_t::replace), // "application/json"); } else { - printf("processing -> %s\n", slot->isProcessing() ? "true" : "false"); - const auto chunked_content_provider = [slot](size_t, DataSink & sink) { + const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) { size_t sent_count = 0; size_t sent_token_probs_index = 0; while(slot->isProcessing()) { if(slot->hasNewToken()) { // new token notification - // const completion_token_output token = slot->next(); - // std::string token_str = llama_token_to_piece(llama.ctx, token.tok); - - size_t pos = std::min(sent_count, llama.generated_text.size()); - - const std::string str_test = llama.generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = - llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); - if (stop_pos != std::string::npos) { - is_stop_full = true; - llama.generated_text.erase( - llama.generated_text.begin() + pos + stop_pos, - llama.generated_text.end()); - pos = std::min(sent_count, llama.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = llama.findStoppingStrings(str_test, token_text.size(), - STOP_PARTIAL); - } - - if ( - stop_pos == std::string::npos || - // Send rest of the text if we are at the end of the generation - (!llama.has_next_token && !is_stop_full && stop_pos > 0) - ) { - const std::string to_send = llama.generated_text.substr(pos, std::string::npos); - - sent_count += to_send.size(); - - std::vector probs_output = {}; - - if (llama.params.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { - probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + const completion_token_output token = slot->next(); + std::string token_str = llama_token_to_piece(llama.ctx, token.tok); + std::vector probs_output = {}; + if (slot->sparams.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama.ctx, token_str, false); + size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; } - sent_token_probs_index = probs_stop_pos; - } - - const json data = format_partial_response(llama, to_send, probs_output); - - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.data(), str.size())) { - LOG_VERBOSE("stream closed", {}); - llama_print_timings(llama.ctx); - return false; + const json data = format_partial_response(llama, slot, token_str, probs_output); + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + if(!sink.write(str.c_str(), str.size())) { + slot->release(); + return false; + } + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); } } - - if 
(!llama.has_next_token) { - // Generation is done, send extra information. - const json data = format_final_response( - llama, - "", - std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) - ); - - // const std::string str = - // "data: " + - // data.dump(-1, ' ', false, json::error_handler_t::replace) + - // "\n\n"; - - // LOG_VERBOSE("data stream", { - // { "to_send", str } - // }); - - // if (!sink.write(str.data(), str.size())) { - // LOG_VERBOSE("stream closed", {}); - // llama_print_timings(llama.ctx); - // return false; - // } + const json data = format_final_response( + llama, slot, + "", + std::vector( + slot->generated_token_probs.begin(), + slot->generated_token_probs.begin() + sent_token_probs_index) + ); + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + if (!sink.write(str.data(), str.size())) { + slot->release(); + return false; + } sink.done(); return true; - }; + }; auto on_complete = [&] (bool) { - //llama.mutex.unlock(); + llama.mutex.unlock(); slot->sent_tokens = 0; slot->generated_token_probs.clear(); slot->release(); }; - //lock.release(); + lock.release(); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); @@ -2036,15 +1863,15 @@ int main(int argc, char **argv) // std::vector probs_output = {}; - if (llama.params.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); - size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { - probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); - } - sent_token_probs_index = probs_stop_pos; - } + // if (llama.params.n_probs > 0) { + // const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); + // size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + // size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + // if (probs_pos < probs_stop_pos) { + // probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + // } + // sent_token_probs_index = probs_stop_pos; + // } // const json data = format_partial_response(llama, to_send, probs_output); @@ -2212,6 +2039,7 @@ int main(int argc, char **argv) { running = llama.updateSlots(); } }); + if (!svr.listen_after_bind()) { return 1;
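
Reviewer note: the core of this patch is that sampling state (grammar, logit bias, penalty history) now lives per slot in `llama_sampling_params` / `llama_sampling_context` instead of the shared `gpt_params`, so parallel requests cannot clobber each other's settings. The stand-alone sketch below illustrates that per-slot idea on the simplest piece of such state, the repeat-penalty history; it is a conceptual illustration with made-up `demo_*` types, not the `llama_sampling_*` implementation.

#include <cstdio>
#include <deque>
#include <vector>

using token_id = int;

// per-slot sampling state: every slot keeps its own penalty window and settings
struct demo_sampling_ctx {
    std::deque<token_id> last_tokens;   // sliding window of recently generated tokens
    size_t repeat_last_n  = 64;         // window size (0 would disable the penalty)
    float  repeat_penalty = 1.10f;      // 1.0 = disabled
};

// apply the classic repeat penalty to raw logits, using only this slot's history
static void apply_repeat_penalty(std::vector<float> & logits, const demo_sampling_ctx & ctx) {
    for (token_id t : ctx.last_tokens) {
        if (t < 0 || (size_t) t >= logits.size()) continue;
        float & l = logits[t];
        // push the logit in whichever direction lowers the token's probability
        l = l > 0.0f ? l / ctx.repeat_penalty : l * ctx.repeat_penalty;
    }
}

static void accept_token(demo_sampling_ctx & ctx, token_id t) {
    ctx.last_tokens.push_back(t);
    while (ctx.last_tokens.size() > ctx.repeat_last_n) ctx.last_tokens.pop_front();
}

int main() {
    demo_sampling_ctx slot_a, slot_b;   // two independent slots, independent histories
    accept_token(slot_a, 2);            // only slot A has seen token 2
    std::vector<float> logits = { 0.1f, 0.5f, 2.0f, -1.0f };
    apply_repeat_penalty(logits, slot_a);
    printf("slot A penalizes token 2: %.3f\n", logits[2]);   // 2.0 / 1.10
    logits = { 0.1f, 0.5f, 2.0f, -1.0f };
    apply_repeat_penalty(logits, slot_b);
    printf("slot B leaves token 2:    %.3f\n", logits[2]);    // 2.0
    return 0;
}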
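
`parse_options_completion()` now writes into the target slot's own `params` / `sparams` instead of the shared `gpt_params`, using the `json_value(body, key, default)` helper so omitted fields fall back to fresh defaults for every request. A compact sketch of that pattern follows; it assumes nlohmann::json (which the server already uses), and the `demo_*` structs are trimmed stand-ins for `slot_params` / `llama_sampling_params`.

#include <cstdio>
#include <cstdint>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// trimmed stand-ins for slot_params / llama_sampling_params
struct demo_params  { int n_predict = 128; uint32_t seed = -1; bool stream = true; };
struct demo_sparams { int top_k = 40; float top_p = 0.95f; float temp = 0.80f; };

// same idea as the server's json_value(): the request value if present, the default otherwise
template <typename T>
static T json_value(const json & body, const std::string & key, const T & def) {
    return body.contains(key) && !body.at(key).is_null() ? body.value(key, def) : def;
}

static void parse_request(const json & body, demo_params & p, demo_sparams & sp) {
    demo_params  dp;  // fresh defaults, so a slot never inherits the previous request's values
    demo_sparams dsp;
    p.stream    = json_value(body, "stream",      false);
    p.n_predict = json_value(body, "n_predict",   dp.n_predict);
    p.seed      = json_value(body, "seed",        dp.seed);
    sp.top_k    = json_value(body, "top_k",       dsp.top_k);
    sp.top_p    = json_value(body, "top_p",       dsp.top_p);
    sp.temp     = json_value(body, "temperature", dsp.temp);
}

int main() {
    demo_params p; demo_sparams sp;
    parse_request(json::parse(R"({"temperature": 0.2, "n_predict": 64})"), p, sp);
    printf("n_predict=%d top_k=%d temp=%.2f\n", p.n_predict, sp.top_k, sp.temp); // 64 40 0.20
    return 0;
}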
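
The new `getSlot()` / `launchSlot()` pair replaces the single-request state: an id of -1 picks the first idle slot, while a non-negative id pins the request to a specific slot. A minimal stand-alone illustration of that selection rule, using a simplified slot type rather than the server's `llama_client_slot`:

#include <cstdio>
#include <vector>

// simplified stand-in for llama_client_slot: just an id and a busy flag
struct demo_slot {
    int  id;
    bool busy = false;
    bool available() const { return !busy; }
};

// id == -1: first available slot; otherwise: the slot with that exact id
static demo_slot * get_slot(std::vector<demo_slot> & slots, int id) {
    for (demo_slot & slot : slots) {
        if ((id == -1 && slot.available()) || slot.id == id) {
            return &slot;
        }
    }
    return nullptr; // every slot busy (or unknown id): the caller must reject the request
}

int main() {
    std::vector<demo_slot> slots = { {0, true}, {1, false}, {2, false} };
    printf("any free -> slot %d\n", get_slot(slots, -1)->id); // 1
    printf("pinned   -> slot %d\n", get_slot(slots,  2)->id); // 2
    return 0;
}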
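
`findStoppingStrings()` now takes the slot's generated text explicitly and records the stop word on any match. Below is a self-contained sketch of the same full/partial scan; the names are stand-alone, and the partial matcher is an approximation of `find_partial_stop_string` (longest prefix of a stop word that is a suffix of the text), not a copy of the server's helper.

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

enum stop_type { STOP_FULL, STOP_PARTIAL };

// returns the position where a (possibly partial) stop word begins, or npos
static size_t find_stop(const std::string & text, size_t last_token_size,
                        stop_type type, const std::vector<std::string> & antiprompt) {
    size_t stop_pos = std::string::npos;
    for (const std::string & word : antiprompt) {
        size_t pos = std::string::npos;
        if (type == STOP_FULL) {
            // only rescan the tail that could have been completed by the last token
            const size_t tmp      = word.size() + last_token_size;
            const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
            pos = text.find(word, from_pos);
        } else {
            // partial stop: does the end of the text match a prefix of the stop word?
            for (size_t n = std::min(word.size() - 1, text.size()); n > 0; --n) {
                if (text.compare(text.size() - n, n, word, 0, n) == 0) {
                    pos = text.size() - n;
                    break;
                }
            }
        }
        if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
            stop_pos = pos;
        }
    }
    return stop_pos;
}

int main() {
    const std::vector<std::string> stops = { "\nUser:" };
    printf("full:    %zu\n", find_stop("Hello!\nUser:", 2, STOP_FULL,    stops)); // 6
    printf("partial: %zu\n", find_stop("Hello!\nUs",    2, STOP_PARTIAL, stops)); // 6
    return 0;
}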
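
The streaming branch of `/completion` now emits one server-sent-events frame per generated token. The framing ("data: <json>\n\n") is trivial but easy to get wrong; a stand-alone sketch of just that piece, assuming nlohmann::json as in the server:

#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// wrap one partial-response payload as a server-sent-events frame,
// matching the "data: ...\n\n" framing used by the streaming endpoint
static std::string sse_frame(const json & payload) {
    return "data: " + payload.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
}

int main() {
    const json chunk = { {"content", "Hel"}, {"stop", false}, {"slot_id", 0} };
    fputs(sse_frame(chunk).c_str(), stdout);
    return 0;
}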