From 569ebf11cfd99ffbacc94ca4e52e2d3925c0ec0b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Oct 2023 16:57:05 +0300
Subject: [PATCH] server : refactor ctx_sampling init + n_ctx + names

---
 examples/server/server.cpp | 256 ++++++++++++++++++-------------------
 1 file changed, 121 insertions(+), 135 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 105b6d92e..5a6c434de 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -341,47 +341,55 @@ struct llama_client_slot
 {
     int id;
     int task_id = -1;
+
+    struct slot_params params;
+
+    slot_state state = IDLE;
+    slot_command command = NONE;
+
     // generation props
-    int32_t n_past = 0;
-    int32_t n_decoded = 0;
-    int32_t i_batch = -1;
-    size_t num_prompt_tokens = 0;
-    int32_t num_prompt_tokens_processed = 0;
+    int32_t n_ctx = 0; // context size per slot
+    int32_t n_past = 0;
+    int32_t n_decoded = 0;
     int32_t n_remaining = -1;
+    int32_t i_batch = -1;
+
+    int32_t num_prompt_tokens = 0;
+    int32_t num_prompt_tokens_processed = 0;
+
+    int32_t multibyte_pending = 0;

     json prompt;
     std::string generated_text;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
-    slot_state state = IDLE;
-    slot_command command = NONE;
+
+    bool infill = false;
+    bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
     bool stopped_limit = false;
+
     std::string stopping_word;
-    int32_t multibyte_pending = 0;
+
+    // sampling
+    struct llama_sampling_params sparams;
+    llama_sampling_context *ctx_sampling = nullptr;
+
+    // multimodal
+    std::vector<slot_image> images;
+
+    // stats
     size_t sent_count = 0;
     size_t sent_token_probs_index = 0;
-    bool infill = false;

     int64_t t_start_process_prompt;
     int64_t t_start_genereration;
     double t_prompt_processing; // ms
     double t_token_generation; // ms

-    struct slot_params params;
-
-    // sampling
-    struct llama_sampling_params sparams;
-    llama_sampling_context* ctx_sampling = nullptr;
-    bool has_next_token = true;
-
-    // multimodal
-    std::vector<slot_image> images;
-
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -397,13 +405,6 @@ struct llama_client_slot
         infill = false;
         generated_token_probs.clear();

-        if (ctx_sampling != nullptr)
-        {
-            llama_sampling_free(ctx_sampling);
-        }
-
-        ctx_sampling = llama_sampling_init(sparams);
-
         for (slot_image &img : images)
         {
             free(img.image_embedding);
@@ -415,17 +416,6 @@ struct llama_client_slot
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }

-    bool load_grammar()
-    {
-        if (ctx_sampling != nullptr)
-        {
-            llama_sampling_free(ctx_sampling);
-        }
-
-        ctx_sampling = llama_sampling_init(sparams);
-        return ctx_sampling != nullptr;
-    }
-
     bool has_budget(gpt_params &global_params) {
         n_remaining = -1;
         if(params.n_predict != -1)
@@ -491,33 +481,33 @@ struct llama_client_slot

 struct llama_server_context
 {
-    std::vector<llama_client_slot> slots;
-
-    // system prompt
-    std::string system_prompt;
-    bool need_update_system_prompt = false;
-    std::vector<llama_token> tokens_system;
-    int32_t num_tokens_system;
-
-    // broadcast to all clients to keep the same prompt format
-    std::string user_name; // this should be the anti prompt
-    std::string assistant_name; // this is for generate the prompt
-
-    bool multimodal = false;
-    clip_ctx *clp_ctx = nullptr;
-    int n_embd;
-
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
-    llama_batch batch;
-    bool all_slots_are_idle = false;
-    gpt_params params;
-    int n_ctx;
-    int n_vocab;
-    int max_ctx_per_slot = -1;
-    bool clean_kv_cache = true;
-    int id_gen;
+    clip_ctx *clp_ctx = nullptr;
+
+    gpt_params params;
+
+    llama_batch batch;
+
+    bool multimodal = false;
+    bool clean_kv_cache = true;
+    bool all_slots_are_idle = false;
+
+    int32_t id_gen;
+    int32_t n_ctx; // total context for all clients / slots
+
+    // system prompt
+    bool system_need_update = false;
+
+    std::string system_prompt;
+    std::vector<llama_token> system_tokens;
+
+    std::string name_user; // this should be the antiprompt
+    std::string name_assistant;
+
+    // slots / clients
+    std::vector<llama_client_slot> slots;

     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
@@ -541,7 +531,7 @@ struct llama_server_context
     bool load_model(const gpt_params &params_)
     {
         params = params_;
-        if(!params.mmproj.empty()) {
+        if (!params.mmproj.empty()) {
             multimodal = true;
             LOG_TEE("Multi Modal Mode Enabled");
             clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
@@ -550,10 +540,11 @@ struct llama_server_context
                 return false;
             }

-            if(params.n_ctx < 2048) { // request larger context for the image embedding
+            if (params.n_ctx < 2048) { // request larger context for the image embedding
                 params.n_ctx = 2048;
             }
         }
+
         std::tie(model, ctx) = llama_init_from_gpt_params(params);
         if (model == nullptr)
         {
@@ -561,18 +552,19 @@ struct llama_server_context
             return false;
         }

-        if(multimodal) {
-            int n_img_embd = clip_n_mmproj_embd(clp_ctx);
-            n_embd = llama_n_embd(model);
-            if (n_img_embd != n_embd) {
-                LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_embd);
+        if (multimodal) {
+            const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
+            const int n_embd_llm = llama_n_embd(model);
+            if (n_embd_clip != n_embd_llm) {
+                LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
                 llama_free(ctx);
                 llama_free_model(model);
                 return false;
             }
         }
+
         n_ctx = llama_n_ctx(ctx);
-        n_vocab = llama_n_vocab(model);
+
         return true;
     }
@@ -581,25 +573,19 @@ struct llama_server_context
         // create slots
         all_slots_are_idle = true;
-        if (max_ctx_per_slot == -1)
-        {
-            max_ctx_per_slot = n_ctx / params.n_parallel; // split context
-        }
-        if (max_ctx_per_slot * params.n_parallel > n_ctx)
-        {
-            printf("Error: The max context per slot is more greater than model context size");
-            return;
-        }
+
+        const int32_t n_ctx_slot = n_ctx / params.n_parallel;

         LOG_TEE("Available slots:\n");
         for (int i = 0; i < params.n_parallel; i++)
         {
             llama_client_slot slot;
+
             slot.id = i;
-            slot.sparams.n_prev = max_ctx_per_slot;
+            slot.n_ctx = n_ctx_slot;
             slot.reset();
-            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
+            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);

             slots.push_back(slot);
         }
@@ -607,7 +593,7 @@ struct llama_server_context

         // empty system prompt
         system_prompt = "";
-        num_tokens_system = 0;
+        system_tokens.clear();
     }

     std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -699,16 +685,16 @@ struct llama_server_context
         {
             slot->params.input_prefix = "";
         }
+
         if (data.count("input_suffix") != 0)
         {
             slot->params.input_suffix = data["input_suffix"];
         }
-
-        // common params
         else
         {
             slot->params.input_suffix = "";
         }
+
         if (data.count("prompt") != 0)
         {
             slot->prompt = data["prompt"];
@@ -717,11 +703,14 @@ struct llama_server_context
         {
             slot->prompt = "";
         }
+
         slot->sparams.logit_bias.clear();
+
         if (json_value(data, "ignore_eos", false))
         {
             slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY;
         }
+
         const auto &logit_bias = data.find("logit_bias");
         if (logit_bias != data.end() && logit_bias->is_array())
         {
@@ -832,36 +821,37 @@ struct llama_server_context
                 }
             }
         }
-        if (!slot->load_grammar())
+
+        if (slot->ctx_sampling != nullptr)
         {
-            return false;
+            llama_sampling_free(slot->ctx_sampling);
         }
-        all_slots_are_idle = false;
+        slot->ctx_sampling = llama_sampling_init(slot->sparams);
         slot->command = LOAD_PROMPT;
+
+        all_slots_are_idle = false;
+
         LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id);
+
         return true;
     }

     void kv_cache_clear() {
         // clear the entire KV cache
-        for (int i = 0; i < params.n_parallel; ++i)
-        {
-            llama_kv_cache_seq_rm(ctx, i, 0, -1);
-        }
+        llama_kv_cache_tokens_rm(ctx, -1, -1);
         clean_kv_cache = false;
     }

     void update_system_prompt()
     {
-        tokens_system = ::llama_tokenize(ctx, system_prompt, true);
-        num_tokens_system = tokens_system.size();
+        system_tokens = ::llama_tokenize(ctx, system_prompt, true);

-        batch.n_tokens = num_tokens_system;
+        llama_batch_clear(batch);

         kv_cache_clear();

         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
+            llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
         }

         if (llama_decode(ctx, batch) != 0)
@@ -873,11 +863,11 @@ struct llama_server_context
         // assign the system KV cache to all parallel sequences
         for (int32_t i = 1; i < params.n_parallel; ++i)
         {
-            llama_kv_cache_seq_cp(ctx, 0, i, 0, num_tokens_system);
+            llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
         }

         LOG_TEE("system prompt updated\n");
-        need_update_system_prompt = false;
+        system_need_update = false;
     }

     void notify_system_prompt_changed() {
@@ -890,8 +880,8 @@ struct llama_server_context
         all_slots_are_idle = true;
         // wait until system prompt load
-        need_update_system_prompt = true;
-        while (need_update_system_prompt) {
+        system_need_update = true;
+        while (system_need_update) {
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
         // system prompt loaded, continue
     }

     void process_system_prompt_data(const json &sys_props) {
         system_prompt = sys_props.value("prompt", "");
-        user_name = sys_props.value("anti_prompt", "");
-        assistant_name = sys_props.value("assistant_name", "");
+        name_user = sys_props.value("anti_prompt", "");
+        name_assistant = sys_props.value("assistant_name", "");

         if (slots.size() > 0)
         {
@@ -908,7 +898,7 @@ struct llama_server_context
         }
         else
         {
-            need_update_system_prompt = true;
+            system_need_update = true;
         }
     }
@@ -1036,14 +1026,14 @@ struct llama_server_context
         }

         // check the limits
-        if (
-            slot.n_decoded > 2 && slot.has_next_token && !slot.has_budget(params))
+        if (slot.n_decoded > 2 && slot.has_next_token && !slot.has_budget(params))
         {
             slot.stopped_limit = true;
             slot.has_next_token = false;
         }
-        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)) {
+        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx))
+        {
             slot.stopped_eos = true;
             slot.has_next_token = false;
             LOG_VERBOSE("eos token found", {});
         }
@@ -1111,12 +1101,12 @@ struct llama_server_context
         return get_formated_generation(slots[0]);
     }

-    json get_formated_generation(llama_client_slot & slot) {
+    json get_formated_generation(llama_client_slot &slot) {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
                                 eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-        return json{
-            {"n_ctx", max_ctx_per_slot},
+        return json {
+            {"n_ctx", slot.n_ctx},
             {"model", params.model_alias},
             {"seed", slot.params.seed},
             {"temp", slot.sparams.temp},
@@ -1219,7 +1209,8 @@ struct llama_server_context
         res.id = slot.task_id;
         res.error = false;
         res.stop = true;
-        static const int n_embd = llama_n_embd(model);
+
+        const int n_embd = llama_n_embd(model);
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {
@@ -1229,7 +1220,9 @@ struct llama_server_context
             {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
             };
-        } else {
+        }
+        else
+        {
             const float *data = llama_get_embeddings(ctx);
             std::vector<float> embedding(data, data + n_embd);
             res.result_json = json
@@ -1312,6 +1305,7 @@ struct llama_server_context
                 n_eval = n_batch;
             }

+            const int n_embd = llama_n_embd(model);
             llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
             if (llama_decode(ctx, batch_img))
             {
@@ -1400,7 +1394,7 @@ struct llama_server_context
             process_tasks();

             // update the system prompt wait until all slots are idle state
-            if (need_update_system_prompt)
+            if (system_need_update)
             {
                 LOG_TEE("updating system prompt\n");
                 update_system_prompt();
@@ -1421,7 +1415,7 @@ struct llama_server_context

             for (llama_client_slot &slot : slots)
             {
-                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
+                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                 {
                     // Shift context
                     const int n_left = slot.n_past - slot.params.n_keep - 1;
@@ -1443,7 +1437,7 @@ struct llama_server_context
                     slot.truncated = true;

                     LOG_VERBOSE("context shift", {
-                        {"n_ctx", n_ctx},
+                        {"n_ctx",  n_ctx},
                         {"n_keep", params.n_keep},
                         {"n_left", n_left},
                     });
@@ -1478,7 +1472,7 @@ struct llama_server_context

                 slot.i_batch = batch.n_tokens;

-                llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
+                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);

                 slot.n_decoded += 1;
                 slot.n_past += 1;
@@ -1537,21 +1531,21 @@ struct llama_server_context
                 {
                     if (slot.params.n_keep < 0)
                     {
-                        slot.params.n_keep = (int)slot.num_prompt_tokens;
+                        slot.params.n_keep = slot.num_prompt_tokens;
                     }
-                    slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
+                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
                     //if input prompt is too big, truncate like normal
-                    if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
+                    if (slot.num_prompt_tokens >= slot.n_ctx)
                     {
                         // applied bug of #3661
-                        const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                        const int n_left = slot.n_ctx - slot.params.n_keep;
                         const int n_block_size = n_left / 2;
                         const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                         std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                         // Use half the left-over space in the context for the prompt
                         new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                         LOG_VERBOSE("input truncated", {
-                            {"n_ctx", max_ctx_per_slot},
+                            {"n_ctx", slot.n_ctx},
                             {"n_keep", slot.params.n_keep},
                             {"n_left", n_left},
                             {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@@ -1559,7 +1553,7 @@ struct llama_server_context
                         slot.truncated = true;
                         prompt_tokens = new_tokens;
                         slot.num_prompt_tokens = prompt_tokens.size();
-                        GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
+                        GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
                     }
                     const size_t ps = slot.num_prompt_tokens;
                     std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
@@ -1569,12 +1563,13 @@ struct llama_server_context
                     LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                 }

-                LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, num_tokens_system + slot.n_past);
-                llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
+                LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+
+                llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);

                 slot.cache_tokens = prompt_tokens;

-                if (slot.n_past == (int) slot.num_prompt_tokens)
+                if (slot.n_past == slot.num_prompt_tokens)
                 {
                     // we have to evaluate at least 1 token to generate logits.
                     LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
@@ -1593,7 +1588,7 @@ struct llama_server_context
                 std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
                 for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                 {
-                    llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
+                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
                 }
                 if (has_images && !ingest_images(slot, n_batch))
                 {
@@ -1842,15 +1837,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.n_ctx = std::stoi(argv[i]);
        }
-        else if (arg == "-cps" || arg == "--ctx-per-slot" || arg == "--ctx_per_slot")
-        {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            llama.max_ctx_per_slot = std::stoi(argv[i]);
-        }
        else if (arg == "--rope-freq-base")
        {
            if (++i >= argc)
@@ -2227,8 +2213,8 @@ int main(int argc, char **argv)
            {
                res.set_header("Access-Control-Allow-Origin", "*");
                json data = {
-                    { "user_name", llama.user_name.c_str() },
-                    { "assistant_name", llama.assistant_name.c_str() }
+                    { "user_name", llama.name_user.c_str() },
+                    { "assistant_name", llama.name_assistant.c_str() }
                };
                res.set_content(data.dump(), "application/json");
            });
@@ -2434,7 +2420,7 @@ int main(int argc, char **argv)
    svr.set_base_dir(sparams.public_path);

    // to make it ctrl+clickable:
-    printf("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
+    LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);

    LOG_INFO("HTTP server listening", {
        {"hostname", sparams.hostname},