diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2f3c3fe4f..18ee4b5be 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -37,6 +37,11 @@ struct server_params
     int32_t write_timeout = 600;
 };
 
+// struct beam_search_callback_data {
+//     llama_server_context* ctx;
+//     llama_client_slot* slot;
+// };
+
 static bool server_verbose = false;
 
 #if SERVER_VERBOSE != 1
@@ -76,7 +81,7 @@ struct slot_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_predict = 128; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
-    bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
+    bool cache_prompt = false; // remember the prompt to avoid reprocessing all of it
     std::vector<std::string> antiprompt;
     json input_prefix;
     json input_suffix;
@@ -246,12 +251,9 @@ struct llama_client_slot
     llama_grammar *grammar = nullptr;
 
     void reset() {
-        state = IDLE;
-        command = NONE;
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
-        generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -261,6 +263,7 @@ struct llama_client_slot
         n_past = 0;
         sent_count = 0;
         infill = false;
+        clean_tokens();
 
         if (grammar != nullptr) {
             llama_grammar_free(grammar);
@@ -300,7 +303,7 @@ struct llama_client_slot
     }
 
     bool hasNewToken() {
-        return generated_token_probs.size() > sent_tokens;
+        return num_tokens_predicted > sent_tokens;
     }
 
     bool available() {
@@ -308,7 +311,7 @@ struct llama_client_slot
     }
 
     bool isProcessing() {
-        return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING;
+        return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING;
     }
 
     completion_token_output next() {
@@ -319,8 +322,6 @@ struct llama_client_slot
 
     void addTokenString(completion_token_output token) {
         if(command == RELEASE) {
-            generated_token_probs.clear();
-            sent_tokens = 0;
             return;
         }
         context_tokens.push_back(token.tok);
@@ -333,6 +334,11 @@ struct llama_client_slot
             command = RELEASE;
         }
     }
+
+    void clean_tokens() {
+        sent_tokens = 0;
+        generated_token_probs.clear();
+    }
 };
 
 struct llama_server_context
@@ -626,9 +632,7 @@ struct llama_server_context
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
         slot.generated_text += token_str;
-        size_t pos = std::min(slot.sent_count, slot.generated_text.size());
-        const std::string str_test = slot.generated_text.substr(pos);
         bool is_stop_full = false;
         size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
@@ -737,14 +741,20 @@ struct llama_server_context
             if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
             {
                 LOG_TEE("slot %i released\n", slot.id);
-                slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
+                slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
+                if(slot.state == SLEEPING) {
+                    printf("%i has cached prompt.\n", slot.id);
+                }
                 slot.command = NONE;
                 continue;
             }
 
             kv_cache_free -= slot.num_prompt_tokens;
-            if (slot.state == IDLE || slot.command == RELEASE) {
+            if (
+                slot.state == IDLE ||
+                slot.state == SLEEPING ||
+                slot.command == RELEASE) {
                 continue;
             }
 
@@ -765,8 +775,6 @@ struct llama_server_context
             // need process the prompt
             bool keep_gen = slot.state == SLEEPING; // remember generation
             if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
-                slot.state = PROCESSING;
-                slot.command = NONE;
                 std::vector<llama_token> prompt_tokens;
                 if(slot.infill) {
                     bool suff_rm_leading_spc = true;
@@ -794,6 +802,9 @@ struct llama_server_context
 
                 slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
 
+                printf("n_past: %i, context: %zu, prompt: %zu, cache: %s\n",
+                       slot.n_past, slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
+
                 slot.context_tokens = prompt_tokens;
 
                 if (slot.n_past == slot.num_prompt_tokens) {
@@ -812,7 +823,7 @@ struct llama_server_context
                     std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
 
                     for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
-                        //printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
+                        printf("%s", llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
                         batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
                         batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
                         batch.seq_id[batch.n_tokens] = slot.id;
@@ -827,6 +838,8 @@ struct llama_server_context
 
                     slot.n_decoded = 0;
                     slot.i_batch   = batch.n_tokens - 1;
+                    slot.state = PROCESSING;
+                    slot.command = NONE;
                 }
             }
         }
@@ -868,10 +881,18 @@ struct llama_server_context
             }
 
             for (auto & slot : slots)
             {
+
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                     continue;
                 }
 
+                // prompt evaluated for embedding
+                if(params.embedding) {
+                    slot.release();
+                    slot.i_batch = -1;
+                    return true;
+                }
+
                 completion_token_output result;
                 const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
 
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
@@ -1316,6 +1337,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
     json res = json{
         {"content", content},
+        {"slot_id", slot->id},
         {"stop", true},
         {"model", llama.params.model_alias},
         {"tokens_predicted", slot->num_tokens_predicted},
@@ -1327,7 +1349,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
         {"stopped_word", slot->stopped_word},
         {"stopped_limit", slot->stopped_limit},
         {"stopping_word", slot->stopping_word},
-        {"tokens_cached", slot->n_past},
+        {"tokens_cached", slot->n_past}
         // {"timings", format_timings(llama)},
     };
 
@@ -1383,6 +1405,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     llama_sampling_params default_sparams;
 
     slot->params.stream = json_value(body, "stream", false);
+    slot->params.cache_prompt = json_value(body, "cache_prompt", false);
     slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict);
     slot->sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
     slot->sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
@@ -1495,8 +1518,8 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
-static bool is_at_eob(llama_server_context &server_context, const llama_token *tokens, const size_t n_tokens) {
-    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+static bool is_at_eob(llama_server_context * server_context, const llama_token *tokens, const size_t n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context->ctx);
 }
 
 // Function matching type llama_beam_search_callback_fn_t.
@@ -1509,21 +1532,21 @@ static bool is_at_eob(llama_server_context &server_context, const llama_token *t
 // AVOID HEADACHES unnecessaries
 // static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
-//     auto & llama = *static_cast<llama_server_context*>(callback_data);
+//     auto & llama = *static_cast<beam_search_callback_data*>(callback_data);
 //     // Mark beams as EOS as needed.
 //     for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
 //         llama_beam_view& beam_view = beams_state.beam_views[i];
-//         if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+//         if (!beam_view.eob && is_at_eob(llama.ctx, beam_view.tokens, beam_view.n_tokens)) {
 //             beam_view.eob = true;
 //         }
 //     }
 //     printf(","); // Show progress
 //     if (const size_t n = beams_state.common_prefix_length) {
-//         llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+//         llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
 //         assert(0u < beams_state.n_beams);
 //         const llama_token * tokens = beams_state.beam_views[0].tokens;
 //         const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
-//         std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+//         std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
 //         printf("%zu", n);
 //     }
 //     fflush(stdout);
@@ -1541,17 +1564,17 @@ struct token_translator {
     std::string operator()(const completion_token_output & cto) const { return (*this)(cto.tok); }
 };
 
-static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot & slot)
+static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot* slot)
 {
-    auto & gtps = slot.generated_token_probs;
+    auto & gtps = slot->generated_token_probs;
     auto translator = token_translator{llama.ctx};
     auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
     const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
-    if (slot.generated_text.capacity() < slot.generated_text.size() + len) {
-        slot.generated_text.reserve(slot.generated_text.size() + len);
+    if (slot->generated_text.capacity() < slot->generated_text.size() + len) {
+        slot->generated_text.reserve(slot->generated_text.size() + len);
     }
     for (const completion_token_output & cto : gtps) {
-        slot.generated_text += translator(cto);
+        slot->generated_text += translator(cto);
     }
 }
 
@@ -1662,17 +1685,19 @@ int main(int argc, char **argv)
 
             std::string completion_text = "";
             if (llama.params.n_beams) {
                 // // Fill llama.generated_token_probs vector with final beam.
-                // llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                //                   slot->n_past, llama.n_remain);
+                // beam_search_callback_data data_;
+                // data_.slot = slot;
+                // data_.ctx = &llama;
+                // llama_beam_search(llama.ctx, beam_search_callback, &data_, llama.params.n_beams,
+                //                   slot->n_past, llama.params.n_predict);
                 // // Translate llama.generated_token_probs to llama.generated_text.
-                // append_to_generated_text_from_generated_token_probs(llama);
+                // append_to_generated_text_from_generated_token_probs(llama, slot);
             } else {
-
                 while (slot->isProcessing()) {
                     if(slot->hasNewToken()) {
                         completion_text += slot->next().text_to_send;
                     } else {
-                        std::this_thread::sleep_for(std::chrono::milliseconds(5));
+                        std::this_thread::sleep_for(std::chrono::microseconds(5));
                     }
                 }
             }
@@ -1686,7 +1711,7 @@ int main(int argc, char **argv)
             const json data = format_final_response(llama, slot, completion_text, probs);
 
             //llama_print_timings(llama.ctx);
-
+            slot->release();
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
         } else {
@@ -1743,8 +1768,7 @@ int main(int argc, char **argv)
                 return true;
             };
             auto on_complete = [slot, &llama] (bool)
            {
-                slot->sent_tokens = 0;
-                slot->generated_token_probs.clear();
+                slot->clean_tokens();
                 slot->release();
             };
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -1781,7 +1805,27 @@ int main(int argc, char **argv)
             return;
         }
 
-        const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
+        if(!slot->params.stream) {
+            std::string completion_text = "";
+            while (slot->isProcessing()) {
+                if(slot->hasNewToken()) {
+                    completion_text += slot->next().text_to_send;
+                } else {
+                    std::this_thread::sleep_for(std::chrono::microseconds(5));
+                }
+            }
+            auto probs = slot->generated_token_probs;
+            if (slot->sparams.n_probs > 0 && slot->stopped_word) {
+                const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
+                probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
+            }
+
+            const json data = format_final_response(llama, slot, completion_text, probs);
+            //llama_print_timings(llama.ctx);
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+        } else {
+            const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
             size_t sent_token_probs_index = 0;
             while(slot->isProcessing()) {
                 if(slot->hasNewToken()) { // new token notification
@@ -1834,11 +1878,11 @@ int main(int argc, char **argv)
                 return true;
            };
            auto on_complete = [slot, &llama] (bool)
            {
-                slot->sent_tokens = 0;
-                slot->generated_token_probs.clear();
+                slot->clean_tokens();
                slot->release();
            };
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+        }
         });
 
     svr.Get("/model.json", [&llama](const Request &, Response &res)
@@ -1878,9 +1922,7 @@ int main(int argc, char **argv)
     svr.Post("/embedding", [&llama](const Request &req, Response &res)
             {
                 const json body = json::parse(req.body);
-                llama_client_slot* slot = llama.getSlot(-1);
-                slot->reset();
                 //llama_reset_timings(llama.ctx);
                 if (body.count("content") != 0)
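For context: with this patch a client opts into prompt caching per request via the new `cache_prompt` field, and the final response reports which slot served it via `slot_id` and how many tokens it kept via `tokens_cached`. A minimal way to exercise the flag might look like the sketch below; the host, port, endpoint path, and prompt are assumptions (the server's usual defaults), not taken from the diff itself:

```sh
# Hypothetical smoke test for the cache_prompt flag added above.
# First call: when generation finishes, the serving slot moves to SLEEPING
# and keeps the prompt's KV cache instead of going back to IDLE.
curl http://localhost:8080/completion -d '{
    "prompt": "Building a website can be done in 10 simple steps:",
    "n_predict": 64,
    "cache_prompt": true
}'
# A repeated call with the same prompt prefix should only have to evaluate
# the tokens past the shared prefix: n_past is computed via
# common_part(slot.context_tokens, prompt_tokens) in the hunk at @@ -794,6 +802,9 @@.
```

The design trade-off is visible in the scheduling loop: a SLEEPING slot still counts against `kv_cache_free` (its prompt tokens stay resident), which is why the release path only moves a slot to SLEEPING when the request asked for `cache_prompt`.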