From 7196c4e08ae60608f60a12676e24034532aaed87 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Wed, 18 Oct 2023 16:50:09 -0400
Subject: [PATCH] new sampling API

---
 examples/server/server.cpp | 168 +++++++++++++++++--------------------
 1 file changed, 78 insertions(+), 90 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 97171b125..8beb19983 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -125,7 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int n_keep = 0; // RNG seed
+    int n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
@@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> & probs)
     return out;
 }
 
+static llama_sampling_context * llama_sampling_init_srv(const llama_sampling_params & sparams, const std::string & grammar, int n_ctx) {
+    llama_sampling_context * result = new llama_sampling_context();
+
+    result->params  = sparams;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(n_ctx);
+
+    return result;
+}
+
 struct slot_image {
     clip_image_u8 img_data;
     bool request_encode_image = false;
@@ -287,7 +315,6 @@ struct llama_client_slot
     int num_tokens_predicted = 0;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
-    std::vector<llama_token> last_n_tokens;
     std::vector<completion_token_output> generated_token_probs;
     int sent_tokens = 0;
     slot_state state = IDLE;
@@ -307,13 +334,12 @@ struct llama_client_slot
     double t_token_generation; // ms
 
     struct slot_params params;
-    struct llama_sampling_params sparams;
-    llama_sampling_context ctx_sampling;
-    bool has_next_token = true;
 
-    // grammar props
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
+    // sampling
+    struct llama_sampling_params sparams;
+    llama_sampling_context* ctx_sampling = nullptr;
+    bool has_next_token = true;
+    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -332,47 +358,26 @@ struct llama_client_slot
         infill = false;
         clean_tokens();
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling.params = sparams;
-            ctx_sampling.grammar = NULL;
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+
         for(slot_image img : images)
         {
            free(img.image_embedding);
           delete[] img.img_data.data;
            img.prefix_prompt = "";
         }
+        images.clear();
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }
 
     bool loadGrammar(llama_token eos) {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = sparams.logit_bias.find(eos);
-                if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling.params = sparams;
-        ctx_sampling.grammar = grammar;
-        return true;
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        return ctx_sampling != nullptr;
     }
 
     bool hasBudget(gpt_params &global_params) {
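The helper above concentrates what used to be scattered across loadGrammar and the per-slot grammar fields: it parses the optional grammar, builds the llama_grammar, sizes the prev history to the slot's context, and returns nullptr on a grammar parse error. A minimal usage sketch of the intended ownership contract follows; the slot object and the error path are illustrative assumptions, not part of this patch.

    // caller owns the sampling context: free the old one before re-initializing
    if (slot.ctx_sampling != nullptr) {
        llama_sampling_free(slot.ctx_sampling);
        slot.ctx_sampling = nullptr;
    }
    slot.ctx_sampling = llama_sampling_init_srv(slot.sparams, slot.params.grammar, slot.max_context_size);
    if (slot.ctx_sampling == nullptr) {
        // the grammar failed to parse; reject the request rather than sampling unconstrained
        report_invalid_grammar(slot); // hypothetical helper, not defined in this patch
    }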
@@ -448,7 +453,6 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     llama_batch batch;
-    std::vector<llama_token_data> candidates;
     bool all_slots_are_idle = false;
     gpt_params params;
     int n_ctx;
@@ -468,11 +472,6 @@ struct llama_server_context
             llama_free_model(model);
             model = nullptr;
         }
-        for(auto &slot : slots) {
-            if(slot.grammar) {
-                llama_grammar_free(slot.grammar);
-            }
-        }
     }
 
     bool loadModel(const gpt_params &params_)
@@ -510,7 +509,6 @@ struct llama_server_context
         }
         n_ctx = llama_n_ctx(ctx);
         n_vocab = llama_n_vocab(model);
-        candidates.reserve(n_vocab);
         return true;
     }
 
@@ -529,13 +527,12 @@ struct llama_server_context
         {
            llama_client_slot slot;
            slot.id = i;
-           slot.last_n_tokens.resize(max_ctx_per_slot);
-           std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+           slot.max_context_size = max_ctx_per_slot;
            slot.reset();
            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
            slots.push_back(slot);
         }
-        batch = llama_batch_init(n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0, 1);
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
@@ -626,10 +623,7 @@ struct llama_server_context
         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            batch.token[i] = tokens_system[i];
-            batch.pos[i] = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0)
        {
@@ -726,8 +720,6 @@ struct llama_server_context
     bool processToken(completion_token_output & result, llama_client_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        slot.last_n_tokens.erase(slot.last_n_tokens.begin());
-        slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
@@ -859,11 +851,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
             if (llama_decode(ctx, batch_view)) {
@@ -878,8 +871,8 @@ struct llama_server_context
                 if (n_eval > n_batch) {
                     n_eval = n_batch;
                 }
-                llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
-                if (llama_decode(ctx, batch)) {
+                llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                if (llama_decode(ctx, batch_img)) {
                    LOG_TEE("%s : failed to eval image\n", __func__);
                    return false;
                }
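The manual writes into batch.token, batch.pos, batch.seq_id and batch.logits that this patch removes are replaced by the llama_batch_add helper from the common library, which also fills the new n_seq_id field and advances batch.n_tokens. Roughly, the helper behaves like the sketch below (a paraphrase under the assumption that the common-library version on this branch matches it; the exact upstream code may differ slightly):

    // sketch: append one token to a llama_batch, tagging it with its sequence ids
    void llama_batch_add(llama_batch & batch, llama_token id, llama_pos pos,
                         const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = seq_ids.size();
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            batch.seq_id[batch.n_tokens][i] = seq_ids[i];
        }
        batch.logits  [batch.n_tokens] = logits;

        batch.n_tokens++;
    }

Because the helper already increments batch.n_tokens, call sites that previously bumped the counter by hand no longer need to do so.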
@@ -894,10 +887,7 @@ struct llama_server_context
                 (json)(slot.images[image_idx].prefix_prompt);
             std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
             for (int i = 0; i < append_tokens.size(); ++i) {
-                batch.token [batch.n_tokens] = append_tokens[i];
-                batch.pos   [batch.n_tokens] = slot.n_past;
-                batch.seq_id[batch.n_tokens] = slot.id;
-                batch.logits[batch.n_tokens] = false;
+                llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
                 slot.n_past += 1;
                 batch.n_tokens += 1;
             }
@@ -922,7 +912,6 @@ struct llama_server_context
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
 
-        // context shift takes effect only when there is a single slot
         for(llama_client_slot &slot : slots)
         {
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
@@ -976,16 +965,12 @@ struct llama_server_context
                 continue;
             }
 
-            batch.token [batch.n_tokens] = slot.sampled;
-            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past;
-            batch.seq_id[batch.n_tokens] = slot.id;
-            batch.logits[batch.n_tokens] = true;
+            slot.i_batch = batch.n_tokens;
+
+            llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
 
             slot.n_decoded += 1;
-            slot.i_batch = batch.n_tokens;
             slot.n_past += 1;
-
-            batch.n_tokens += 1;
         }
 
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
@@ -1026,7 +1011,7 @@ struct llama_server_context
                 slot.num_prompt_tokens = prompt_tokens.size();
 
                 if(!slot.params.cache_prompt) {
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
                     slot.n_past = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 } else {
@@ -1038,23 +1023,27 @@ struct llama_server_context
                    //if input prompt is too big, truncate like normal
                    if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                    {
+                        // applied bug of #3661
                        const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                        const int n_block_size = n_left / 2;
+                        const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                        // Use half the left-over space in the context for the prompt
-                        new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+                        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx", n_ctx},
-                            {"n_keep", params.n_keep},
+                            {"n_ctx", max_ctx_per_slot},
+                            {"n_keep", slot.params.n_keep},
                            {"n_left", n_left},
                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                        });
                        slot.truncated = true;
                        prompt_tokens = new_tokens;
                        slot.num_prompt_tokens = prompt_tokens.size();
+                        GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
                    }
                    const size_t ps = slot.num_prompt_tokens;
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                    std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
+                    std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                    slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
                    LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
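The block-wise truncation above keeps the first n_keep prompt tokens, then drops whole blocks of n_block_size = n_left / 2 tokens from the middle until the remainder fits the slot's context. A small self-contained sketch of the same arithmetic, using plain ints instead of llama_token so the numbers can be checked by hand:

    #include <cassert>
    #include <vector>

    // mirror of the block-wise truncation used above
    std::vector<int> truncate_prompt(const std::vector<int> & prompt, int n_ctx_slot, int n_keep) {
        if ((int) prompt.size() < n_ctx_slot) {
            return prompt; // already fits
        }
        const int n_left        = n_ctx_slot - n_keep;
        const int n_block_size  = n_left / 2;
        const int erased_blocks = ((int) prompt.size() - n_keep - n_block_size) / n_block_size;

        // keep the first n_keep tokens, then everything after the erased middle blocks
        std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
        out.insert(out.end(), prompt.begin() + n_keep + erased_blocks * n_block_size, prompt.end());

        assert((int) out.size() < n_ctx_slot);
        return out;
    }

For example, with a 2048-token slot, n_keep = 256 and a 5000-token prompt: n_left = 1792, n_block_size = 896, erased_blocks = (5000 - 256 - 896) / 896 = 4, so the 4 * 896 = 3584 tokens at positions 256..3839 are dropped and 256 + 1160 = 1416 tokens remain, which satisfies the assertion.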
@@ -1081,11 +1070,7 @@ struct llama_server_context
                    // process the prefix of first image
                    std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
                    for (; slot.n_past < prefix_tokens.size(); ++slot.n_past)
                    {
-                        batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
-                        batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
-                        batch.seq_id[batch.n_tokens] = slot.id;
-                        batch.logits[batch.n_tokens] = false;
-                        batch.n_tokens += 1;
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
                    }
 
                    if(ingest_images && !ingestImages(slot, n_batch))
                    {
@@ -1113,11 +1098,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
@@ -1150,25 +1136,27 @@ struct llama_server_context
             }
 
             completion_token_output result;
-            const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+            llama_sampling_accept(slot.ctx_sampling, ctx, id);
 
             if (slot.n_decoded == 1)
             {
                 slot.t_start_genereration = ggml_time_us();
                 slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
             }
 
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
 
             const int32_t n_probs = slot.sparams.n_probs;
             if (slot.sparams.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &candidates_p);
+                llama_sample_softmax(ctx, &cur_p);
             }
 
-            for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
-                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }
 
             if (!processToken(result, slot)) {
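With this change the server no longer maintains its own last_n_tokens ring buffer or candidates vector: llama_sampling_sample fills ctx_sampling->cur with the candidate array, and llama_sampling_accept records the chosen token so that repetition penalties and the grammar state stay in sync. Conceptually, the accept step amounts to something like the sketch below (a paraphrase of the common-library helper assumed by this patch, not the authoritative source):

    // push the sampled token into the bounded history and advance the grammar, if any
    void llama_sampling_accept(llama_sampling_context * ctx_sampling, llama_context * ctx_main, llama_token id) {
        ctx_sampling->prev.erase(ctx_sampling->prev.begin());
        ctx_sampling->prev.push_back(id);

        if (ctx_sampling->grammar != NULL) {
            llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
        }
    }

This is why processToken above no longer touches a per-slot history: as long as every sampled token goes through llama_sampling_accept, ctx_sampling->prev (sized to the slot context in llama_sampling_init_srv) is the single source of truth for the repetition-penalty window.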