From 5b8e29de533110d6fc47491ba4e16d602fbc4529 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 12 Oct 2023 17:09:12 -0400 Subject: [PATCH] multiple client support --- examples/server/server.cpp | 53 +++++++++++++------------------------- 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3e073de1c..53a209736 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -76,7 +76,7 @@ struct slot_params { uint32_t seed = -1; // RNG seed int32_t n_predict = 128; // new tokens to predict std::string grammar = ""; // optional BNF-like grammar to constrain sampling - bool remember_generation = false; // remember a part of the prompt to avoid reprocessing all prompt + bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt std::vector antiprompt; }; @@ -256,6 +256,7 @@ struct llama_client_slot stopping_word = ""; multibyte_pending = 0; n_past = 0; + sent_count = 0; if (grammar != nullptr) { llama_grammar_free(grammar); @@ -299,8 +300,7 @@ struct llama_client_slot } bool available() { - return state == IDLE && - command == NONE && !params.remember_generation; + return state == IDLE && command == NONE; } bool isProcessing() { @@ -354,12 +354,6 @@ struct llama_server_context int n_ctx; int n_vocab; bool clean_kv_cache = true; - std::mutex mutex; - - std::unique_lock lock() - { - return std::unique_lock(mutex); - } ~llama_server_context() { @@ -406,7 +400,7 @@ struct llama_server_context slot.last_n_tokens.resize(params.n_predict); // max prediction per slot slot.reset(); std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0); - LOG_TEE(" - slot %i\n", slot.id); + LOG_TEE(" -> Slot %i\n", slot.id); slots.push_back(slot); } LOG_TEE("Context Size: %i\n", params.n_ctx); @@ -716,7 +710,6 @@ struct llama_server_context slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.sent_count, slot.generated_text.size()); - result.tok = -1; } else { is_stop_full = false; stop_pos = findStoppingStrings(str_test, token_str.size(), @@ -737,7 +730,6 @@ struct llama_server_context { slot.generated_token_probs.push_back(result); } - if (slot.multibyte_pending > 0) { slot.multibyte_pending -= token_str.size(); @@ -780,7 +772,6 @@ struct llama_server_context slot.stopped_eos = true; LOG_VERBOSE("eos token found", {}); } - LOG_VERBOSE("next token", { {"token", result.tok}, {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, @@ -818,18 +809,11 @@ struct llama_server_context if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken()) { LOG_TEE("slot %i released\n", slot.id); - if(!slot.params.remember_generation) { - llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system, n_ctx); - slot.state = IDLE; - slot.command = NONE; - slot.generated_text.clear(); - return true; - } else { - slot.state = SLEEPING; - slot.command = NONE; - } + slot.state = slot.params.remember_generation ? SLEEPING : IDLE; + slot.command = NONE; continue; } + kv_cache_free -= slot.num_prompt_tokens; if (slot.state == IDLE || slot.command == RELEASE) { @@ -858,23 +842,28 @@ struct llama_server_context auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt slot.num_prompt_tokens = prompt_tokens.size(); + slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0; slot.context_tokens = prompt_tokens; + if (slot.n_past == slot.num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. + printf("we have to evaluate at least 1 token to generate logits\n"); + slot.n_past--; + } + + llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1); + LOG_VERBOSE("prompt ingested", { {"n_past", slot.n_past}, {"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)}, {"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())}, }); - if(system_prompt.empty()) { - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); - } - std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0); for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) { - printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str()); + //printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str()); batch.token [batch.n_tokens] = prompt_tokens[slot.n_past]; batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system; batch.seq_id[batch.n_tokens] = slot.id; @@ -1693,7 +1682,6 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const Request &req, Response &res) { - auto lock = llama.lock(); json data = json::parse(req.body); @@ -1763,13 +1751,12 @@ int main(int argc, char **argv) // "application/json"); } else { const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) { - size_t sent_count = 0; size_t sent_token_probs_index = 0; while(slot->isProcessing()) { if(slot->hasNewToken()) { // new token notification const completion_token_output token = slot->next(); std::vector probs_output = {}; - if (slot->sparams.n_probs > 0) { + if (slot->sparams.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false); size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size()); @@ -1816,12 +1803,10 @@ int main(int argc, char **argv) return true; }; auto on_complete = [slot, &llama] (bool) { - llama.mutex.unlock(); slot->sent_tokens = 0; slot->generated_token_probs.clear(); slot->release(); }; - lock.release(); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); @@ -1956,7 +1941,6 @@ int main(int argc, char **argv) svr.Post("/tokenize", [&llama](const Request &req, Response &res) { - auto lock = llama.lock(); const json body = json::parse(req.body); std::vector tokens; @@ -1969,7 +1953,6 @@ int main(int argc, char **argv) svr.Post("/detokenize", [&llama](const Request &req, Response &res) { - auto lock = llama.lock(); const json body = json::parse(req.body); std::string content;