From ab2fc002248f9098b02a8a015f5b5f6f03836cc2 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 18 Oct 2023 16:57:48 -0400 Subject: [PATCH] latest changes of sampling API --- examples/server/server.cpp | 816 +++++++++++++++++++++++-------------- 1 file changed, 517 insertions(+), 299 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1e2ae5ea8..8beb19983 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,6 +3,9 @@ #include "build-info.h" #include "grammar-parser.h" +#include "../llava/clip.h" +#include "stb_image.h" + #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error #define CPPHTTPLIB_NO_EXCEPTIONS 1 @@ -308,18 +311,14 @@ struct llama_client_slot int32_t n_remaining = -1; json prompt; - std::vector embd; - std::vector last_n_tokens; - - llama_model *model = nullptr; - llama_context *ctx = nullptr; - gpt_params params; - llama_sampling_context ctx_sampling; - int n_ctx; - - grammar_parser::parse_state parsed_grammar; - llama_grammar *grammar = nullptr; - + std::string generated_text = ""; + int num_tokens_predicted = 0; + llama_token sampled; + std::vector cache_tokens; + std::vector generated_token_probs; + int sent_tokens = 0; + slot_state state = IDLE; + slot_command command = NONE; bool truncated = false; bool stopped_eos = false; bool stopped_word = false; @@ -475,34 +474,22 @@ struct llama_server_context } } - bool loadModel(const gpt_params ¶ms_) - { - params.antiprompt.clear(); - params.grammar.clear(); - num_prompt_tokens = 0; - num_tokens_predicted = 0; - generated_text = ""; - generated_text.reserve(n_ctx); - generated_token_probs.clear(); - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_remain = 0; - n_past = 0; - - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; - ctx_sampling = llama_sampling_context_init(params, NULL); - } - } - bool loadModel(const gpt_params ¶ms_) { params = params_; + if(!params.mmproj.empty()) { + multimodal = true; + LOG_TEE("Multi Modal Mode Enabled"); + clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1); + if(clp_ctx == nullptr) { + LOG_ERROR("unable to load clip model", {{"model", params.mmproj}}); + return false; + } + + if(params.n_ctx < 2048) { // request larger context for the image embedding + params.n_ctx = 2048; + } + } std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == nullptr) { @@ -521,8 +508,7 @@ struct llama_server_context } } n_ctx = llama_n_ctx(ctx); - last_n_tokens.resize(n_ctx); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + n_vocab = llama_n_vocab(model); return true; } @@ -597,29 +583,24 @@ struct llama_server_context return prompt_tokens; } - bool loadGrammar() - { - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); - + llama_client_slot* getSlot(int id) { + for (llama_client_slot & slot : slots) + { + if ((id == -1 && slot.available()) || slot.id == id) { - auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) { - LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); - } + return &slot; } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - ctx_sampling = llama_sampling_context_init(params, grammar); + return nullptr; + } + + bool launchSlot(llama_client_slot* &slot) { + if(!slot->loadGrammar(llama_token_eos(ctx))) { + return false; + } + all_slots_are_idle = false; + slot->command = LOAD_PROMPT; + LOG_TEE("slot %i is processing\n", slot->id); return true; } @@ -665,119 +646,253 @@ struct llama_server_context // release all slots for (llama_client_slot &slot : slots) { - params.n_keep = (int)num_prompt_tokens; + slot.release(); } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + waitAllAreIdle(); + all_slots_are_idle = true; - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) - { - printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); - // todo we probably want to cut from both sides - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", params.n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; + // wait until system prompt load + update_system_prompt = true; + while(update_system_prompt) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } - - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - embd = prompt_tokens; - - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - printf("we have to evaluate at least 1 token to generate logits\n"); - n_past--; - } - - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - LOG_VERBOSE("prompt ingested", { - {"n_past", n_past}, - {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); - - has_next_token = true; + // system prompt loaded, continue } - void loadPrompt() + + void processSystemPromptData(json sys_props) { + system_prompt = sys_props.value("prompt", ""); + user_name = sys_props.value("anti_prompt", ""); + assistant_name = sys_props.value("assistant_name", ""); + if(slots.size() > 0) { + notifySystemPromptChanged(); + } else { + update_system_prompt = true; + } + } + + void waitAllAreIdle() { + bool wait = true; + while(wait) { + wait = false; + for (auto &slot : slots) + { + if (!slot.available()) + { + wait = true; + break; + } + } + } + } + + size_t findStoppingStrings(const std::string &text, const size_t last_token_size, + const stop_type type, llama_client_slot &slot) { - auto prompt_tokens = tokenize(prompt, true); // always add BOS + size_t stop_pos = std::string::npos; + for (const std::string &word : slot.params.antiprompt) + { + size_t pos; + if (type == STOP_FULL) + { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + pos = text.find(word, from_pos); + } + else + { + pos = find_partial_stop_string(word, text); + } + if (pos != std::string::npos && + (stop_pos == std::string::npos || pos < stop_pos)) + { + if (type == STOP_FULL) + { + slot.stopped_word = true; + slot.stopping_word = word; + slot.has_next_token = false; + } + stop_pos = pos; - num_prompt_tokens = prompt_tokens.size(); + } + } + return stop_pos; + } - if (params.n_keep < 0) + bool processToken(completion_token_output & result, llama_client_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = llama_token_to_piece(ctx, result.tok); + slot.sampled = result.tok; + + // search stop word and delete it + slot.generated_text += token_str; + slot.has_next_token = true; + + size_t pos = std::min(slot.sent_count, slot.generated_text.size()); + const std::string str_test = slot.generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot); + if (stop_pos != std::string::npos) { + is_stop_full = true; + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.sent_count, slot.generated_text.size()); + } else { + is_stop_full = false; + stop_pos = findStoppingStrings(str_test, token_str.size(), + STOP_PARTIAL, slot); + } + + // check if there is any token to predict + if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.sent_count += result.text_to_send.size(); + // add the token to slot queue and cache + } + slot.addTokenString(result); + if (slot.multibyte_pending > 0) + { + slot.multibyte_pending -= token_str.size(); + } + else if (token_str.size() == 1) + { + const char c = token_str[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) + { + slot.multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF0) == 0xE0) + { + slot.multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF8) == 0xF0) + { + slot.multibyte_pending = 3; + } + else + { + slot.multibyte_pending = 0; + } + } + + if (slot.multibyte_pending > 0 && !slot.has_next_token) + { + slot.has_next_token = true; + } + + // check the limits + if ( + slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params)) { slot.stopped_limit = true; slot.has_next_token = false; } - params.n_keep = std::min(n_ctx - 4, params.n_keep); - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)n_ctx) - { - const int n_left = (n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; - } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){ + slot.stopped_eos = true; + slot.has_next_token = false; + LOG_VERBOSE("eos token found", {}); } - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); + LOG_VERBOSE("next token", { + {"token", result.tok}, + {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"num_tokens_predicted", slot.num_tokens_predicted}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }); + return slot.has_next_token; // continue + } - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - n_past--; + bool processImages(llama_client_slot &slot) { + for(slot_image &img : slot.images) { + if(!img.request_encode_image) { + continue; + } + clip_image_f32 img_res; + if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, /*pad2square =*/ true)) { + LOG_TEE("Error processing the given image"); + clip_free(clp_ctx); + return false; + } + img.image_tokens = clip_n_patches(clp_ctx); + img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx)); + if (!img.image_embedding) { + LOG_TEE("Unable to allocate memory for image embeddings\n"); + clip_free(clp_ctx); + return false; + } + LOG_TEE("slot %i - encoding image %i\n", slot.id, img.id); + if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, img.image_embedding)) { + LOG_TEE("Unable to encode image\n"); + return false; + } + img.request_encode_image = false; } + return slot.images.size() > 0; + } - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + // for multiple images processing + bool ingestImages(llama_client_slot &slot, int n_batch) { + int image_idx = 0; + while(image_idx < slot.images.size()) { + slot_image img = slot.images[image_idx]; - LOG_VERBOSE("prompt ingested", { - {"n_past", n_past}, - {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - }); + // process prefix prompt + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + if (llama_decode(ctx, batch_view)) { + LOG_TEE("%s : failed to eval\n", __func__); + return false; + } + } - has_next_token = true; + // process image with llm + for (int i = 0; i < img.image_tokens; i += n_batch) { + int n_eval = img.image_tokens - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; + if (llama_decode(ctx, batch_img)) { + LOG_TEE("%s : failed to eval image\n", __func__); + return false; + } + slot.n_past += n_eval; + } + image_idx++; + + // append prefix of next image + batch.n_tokens = 0; + const auto json_prompt = (image_idx >= slot.images.size()) ? + slot.params.input_suffix : // no more images, then process suffix prompt + (json)(slot.images[image_idx].prefix_prompt); + std::vector append_tokens = tokenize(json_prompt, false); // has next image + for (int i = 0; i < append_tokens.size(); ++i) { + llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true); + slot.n_past += 1; + batch.n_tokens += 1; + } + } + return true; } bool updateSlots() { @@ -814,159 +929,198 @@ struct llama_server_context slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - truncated = true; - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); - } + slot.n_past -= n_discard; - bool tg = true; - while (n_past < embd.size()) - { - int n_eval = (int)embd.size() - n_past; - tg = n_eval == 1; - if (n_eval > params.n_batch) - { - n_eval = params.n_batch; - } + slot.truncated = true; - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) - { - LOG_ERROR("failed to eval", { - {"n_eval", n_eval}, - {"n_past", n_past}, - {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, + LOG_VERBOSE("input truncated", { + {"n_ctx", n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, }); - has_next_token = false; - return result; - } - n_past += n_eval; - } - - if (params.n_predict == 0) - { - has_next_token = false; - result.tok = llama_token_eos(ctx); - return result; - } - - { - // out of user input, sample next token - std::vector candidates; - candidates.reserve(llama_n_vocab(model)); - - result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates); - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - const int32_t n_probs = params.sampling_params.n_probs; - if (params.sampling_params.temp <= 0 && n_probs > 0) - { - // For llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &candidates_p); - } - - for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) - { - result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); - } - - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(result.tok); - if (tg) { - num_tokens_predicted++; } } - // add it to the context - embd.push_back(result.tok); - // decrement remaining sampling budget - --n_remain; - - if (!embd.empty() && embd.back() == llama_token_eos(ctx)) - { - // stopping_word = llama_token_to_piece(ctx, embd.back()); - has_next_token = false; - stopped_eos = true; - LOG_VERBOSE("eos token found", {}); - return result; - } - - has_next_token = params.n_predict == -1 || n_remain != 0; - return result; - } - - size_t findStoppingStrings(const std::string &text, const size_t last_token_size, - const stop_type type) - { - size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) - { - size_t pos; - if (type == STOP_FULL) + // decode any currently ongoing sequences + for (auto & slot : slots) { + // release the slot + if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken()) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); - } - else - { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_FULL) - { - stopping_word = word; - stopped_word = true; - has_next_token = false; + slot.state = slot.params.cache_prompt ? SLEEPING : IDLE; + if(slot.state == SLEEPING) { + LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.cache_tokens.size()); + } else { + LOG_TEE("slot %i released\n", slot.id); + } + slot.command = NONE; + continue; + } + + kv_cache_free -= slot.num_prompt_tokens; + + if ( + slot.state == IDLE || + slot.state == SLEEPING || + slot.command == RELEASE) { + continue; + } + + slot.i_batch = batch.n_tokens; + + llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true); + + slot.n_decoded += 1; + slot.n_past += 1; + } + // process in chunks of params.n_batch + int32_t n_batch = params.n_batch; + + // assign workload to the slots + if (params.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // need process the prompt + if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) { + slot.state = PROCESSING; + slot.command = NONE; + std::vector prompt_tokens; + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_genereration = 0; + + if(slot.infill) { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + auto prefix_tokens = tokenize(slot.params.input_prefix, false); + auto suffix_tokens = tokenize(slot.params.input_suffix, false); + const int space_token = 29871; + if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(ctx)); + prompt_tokens = prefix_tokens; + } else { + prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt + } + + slot.num_prompt_tokens = prompt_tokens.size(); + + if(!slot.params.cache_prompt) { + std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0); + slot.n_past = 0; + slot.num_prompt_tokens_processed = slot.num_prompt_tokens; + } else { + if (slot.params.n_keep < 0) + { + slot.params.n_keep = (int)slot.num_prompt_tokens; + } + slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep); + //if input prompt is too big, truncate like normal + if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot) + { + // applied bug of #3661 + const int n_left = max_ctx_per_slot - slot.params.n_keep; + const int n_block_size = n_left / 2; + const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); + // Use half the left-over space in the context for the prompt + new_tokens.insert(new_tokens.end(), prompt_tokens.end() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); + LOG_VERBOSE("input truncated", { + {"n_ctx", max_ctx_per_slot}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + }); + slot.truncated = true; + prompt_tokens = new_tokens; + slot.num_prompt_tokens = prompt_tokens.size(); + GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot); + } + const size_t ps = slot.num_prompt_tokens; + std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps); + slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; + LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); + } + + llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1); + + slot.cache_tokens = prompt_tokens; + + if (slot.n_past == slot.num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. + printf("we have to evaluate at least 1 token to generate logits\n"); + slot.n_past--; + } + + LOG_VERBOSE("prompt ingested", { + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, + {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, + }); + + bool ingest_images = processImages(slot); // has images? + + // process the prefix of first image + std::vector prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens; + for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) { + llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false); + } + + if(ingest_images && !ingestImages(slot, n_batch)) { + LOG_TEE("failed processing images\n"); + return false; + } + + // extract the logits only for the last token + if (batch.n_tokens > 0) { + batch.logits[batch.n_tokens - 1] = true; + } + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; } - stop_pos = pos; } } - return stop_pos; - } - completion_token_output doCompletion() - { - auto token_with_probs = nextToken(); - - const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); - generated_text += token_text; - - if (params.sampling_params.n_probs > 0) - { - generated_token_probs.push_back(token_with_probs); + if (batch.n_tokens == 0) { + all_slots_are_idle = true; + return true; } - if (multibyte_pending > 0) - { - multibyte_pending -= token_text.size(); - } - else if (token_text.size() == 1) - { - const char c = token_text[0]; - // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) - { - multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF0) == 0xE0) - { - multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF8) == 0xF0) - { - multibyte_pending = 3; - } - else - { - multibyte_pending = 0; + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + return false; + } + + LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + continue; } for (auto & slot : slots) { @@ -1625,9 +1779,77 @@ static void parse_options_completion(const json &body, llama_client_slot* slot, } } - llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar); + LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); + if(!llama.multimodal) { + return; + } - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); + const auto &images_data = body.find("image_data"); + if (images_data != body.end() && images_data->is_array()) + { + for (const auto &img : *images_data) + { + slot_image img_sl; + std::string data_b64 = img["data"].get(); + img_sl.id = img.count("id") != 0 ? img["id"].get() : slot->images.size(); + int width, height, channels; + std::vector image_buffer = base64_decode(data_b64); + data_b64.clear(); + auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3); + if(!data) { + LOG_TEE("slot %i - failed to load image id= %i\n", slot->id, img_sl.id); + return; + } + LOG_TEE("slot %i - image id= %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height); + img_sl.img_data.nx = width; + img_sl.img_data.ny = height; + img_sl.img_data.size = width * height * 3; + img_sl.img_data.data = new uint8_t[width * height * 3](); + memcpy(img_sl.img_data.data, data, width * height * 3); + stbi_image_free(data); + img_sl.request_encode_image = true; + slot->images.push_back(img_sl); + } + // process prompt + // example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]} + if(slot->images.size() > 0 && !slot->prompt.is_array()) { + std::string prompt = slot->prompt.get(); + size_t pos = 0, begin_prefix = 0; + std::string pattern = "[img-"; + while ((pos = prompt.find(pattern, pos)) != std::string::npos) { + size_t end_prefix = pos; + pos += pattern.length(); + size_t end_pos = prompt.find("]", pos); + if (end_pos != std::string::npos) { + std::string image_id = prompt.substr(pos, end_pos - pos); + try { + int img_id = std::stoi(image_id); + bool found = false; + for(slot_image &img : slot->images) { + if(img.id == img_id) { + found = true; + img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); + begin_prefix = end_pos + 1; + break; + } + } + if(!found) { + LOG_TEE("ERROR: Image with id %i not found.\n", img_id); + slot->images.clear(); + return; + } + } catch (const std::invalid_argument& e) { + LOG_TEE("Invalid image number id in prompt\n"); + slot->images.clear(); + return; + } + } + } + slot->prompt = ""; + slot->params.input_suffix = prompt.substr(begin_prefix); + slot->params.cache_prompt = false; // multimodal doesn't support cache prompt + } + } } static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot) @@ -2144,10 +2366,6 @@ int main(int argc, char **argv) { return 1; } - - if (llama.grammar != nullptr) { - llama_grammar_free(llama.grammar); - } llama_backend_free(); return 0; }