From 9d98cdda2cf95cefcdfd84347645c2bc2ba379ad Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 13 Oct 2023 18:42:44 -0400 Subject: [PATCH] llava multimodal integration --- examples/server/CMakeLists.txt | 2 +- examples/server/server.cpp | 312 ++++++++++++++++++++++++++++----- 2 files changed, 274 insertions(+), 40 deletions(-) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 3782f9b80..a23ddcc55 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -6,7 +6,7 @@ install(TARGETS ${TARGET} RUNTIME) target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT}) if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0741ebf11..f3e6b6e39 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,6 +3,13 @@ #include "build-info.h" #include "grammar-parser.h" +// #define SERVER_MULTIMODAL_SUPPORT + +#ifdef SERVER_MULTIMODAL_SUPPORT +#include "../llava/clip.h" +#include "stb_image.h" +#endif + #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error #define CPPHTTPLIB_NO_EXCEPTIONS 1 @@ -61,6 +68,57 @@ static bool server_verbose = false; #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) +#ifdef SERVER_MULTIMODAL_SUPPORT +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::vector base64_decode(std::string const& encoded_string) { + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + BYTE char_array_4[4], char_array_3[3]; + std::vector ret; + while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i ==4) { + for (i = 0; i <4; i++) + char_array_4[i] = base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret.push_back(char_array_3[i]); + i = 0; + } + } + + if (i) { + for (j = i; j <4; j++) + char_array_4[j] = 0; + + for (j = 0; j <4; j++) + char_array_4[j] = base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) ret.push_back(char_array_3[j]); + } + + return ret; +} +#endif // parallel enum slot_state @@ -258,6 +316,13 @@ struct llama_client_slot grammar_parser::parse_state parsed_grammar; llama_grammar *grammar = nullptr; +#ifdef SERVER_MULTIMODAL_SUPPORT + clip_image_u8 img_data; + bool request_encode_image = false; + float* image_embedding = nullptr; + int image_tokens = 0; +#endif + void reset() { num_prompt_tokens = 0; generated_text = ""; @@ -375,6 +440,12 @@ struct llama_server_context std::string user_name = ""; // this should be the anti prompt std::string assistant_name = ""; // this is for generate the prompt +#ifdef SERVER_MULTIMODAL_SUPPORT + bool multimodal = false; + clip_ctx *clp_ctx = nullptr; + int n_embd; +#endif + llama_model *model = nullptr; llama_context *ctx = nullptr; llama_batch batch; @@ -407,12 +478,40 @@ struct llama_server_context bool loadModel(const gpt_params ¶ms_) { params = params_; +#ifdef SERVER_MULTIMODAL_SUPPORT + if(!params.mmproj.empty()) { + multimodal = true; + LOG_TEE("Multi Modal Mode Enabled"); + clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1); + if(clp_ctx == nullptr) { + LOG_ERROR("unable to load clip model", {{"model", params.mmproj}}); + return false; + } + + if(params.n_ctx < 2048) { // request larger context for the image embedding + params.n_ctx = 2048; + } + } +#endif std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == nullptr) { - LOG_ERROR("unable to load model", {{"model", params_.model}}); + LOG_ERROR("unable to load model", {{"model", params.model}}); return false; } + +#ifdef SERVER_MULTIMODAL_SUPPORT + if(multimodal) { + int n_img_embd = clip_n_mmproj_embd(clp_ctx); + n_embd = llama_n_embd(model); + if (n_img_embd != n_embd) { + LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_embd); + llama_free(ctx); + llama_free_model(model); + return false; + } + } +#endif n_ctx = llama_n_ctx(ctx); n_vocab = llama_n_vocab(model); candidates.reserve(n_vocab); @@ -729,6 +828,37 @@ struct llama_server_context std::this_thread::sleep_for(std::chrono::milliseconds(5)); } + // context shift takes effect only when there is a single slot + if(slots.size() == 1) { + llama_client_slot slot = slots[0]; + if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx) + { + // Shift context + const int n_left = slot.n_past - params.n_keep - 1; + const int n_discard = n_left / 2; + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) + { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } + + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + + slot.n_past -= n_discard; + + slot.truncated = true; + + LOG_VERBOSE("input truncated", { + {"n_ctx", n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + }); + } + } + // decode any currently ongoing sequences for (auto & slot : slots) { // release the slot @@ -764,6 +894,8 @@ struct llama_server_context batch.n_tokens += 1; } + // process in chunks of params.n_batch + int32_t n_batch = params.n_batch; // assign workload to the slots if (params.cont_batching || batch.n_tokens == 0) { @@ -772,9 +904,35 @@ struct llama_server_context if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) { slot.state = PROCESSING; slot.command = NONE; + std::vector prompt_tokens; +#ifdef SERVER_MULTIMODAL_SUPPORT + bool ingest_image = false; + if(slot.request_encode_image) { + ingest_image = true; + clip_image_f32 img_res; + if (!clip_image_preprocess(clp_ctx, &slot.img_data, &img_res, /*pad2square =*/ true)) { + LOG_TEE("Error processing the given image"); + clip_free(clp_ctx); + return false; + } + slot.image_tokens = clip_n_patches(clp_ctx); + slot.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx)); + if (!slot.image_embedding) { + LOG_TEE("Unable to allocate memory for image embeddings\n"); + clip_free(clp_ctx); + return false; + } + LOG_TEE("slot %i - encoding image\n", slot.id); + if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, slot.image_embedding)) { + LOG_TEE("Unable to encode image\n"); + return false; + } + slot.request_encode_image = false; + } +#endif slot.t_start_process_prompt = ggml_time_us(); slot.t_start_genereration = 0; - std::vector prompt_tokens; + if(slot.infill) { bool suff_rm_leading_spc = true; if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { @@ -809,7 +967,6 @@ struct llama_server_context slot.n_past--; } - slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; if(!slot.params.cache_prompt) { @@ -848,6 +1005,68 @@ struct llama_server_context {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, }); + +#ifdef SERVER_MULTIMODAL_SUPPORT + std::vector preffix_tokens = ingest_image ? tokenize(slot.params.input_prefix, true) : prompt_tokens; + for (; slot.n_past < preffix_tokens.size(); ++slot.n_past) { + printf(llama_token_to_piece(ctx, preffix_tokens[slot.n_past]).c_str()); + batch.token [batch.n_tokens] = preffix_tokens[slot.n_past]; + batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system; + batch.seq_id[batch.n_tokens] = slot.id; + batch.logits[batch.n_tokens] = false; + batch.n_tokens += 1; + } + if(ingest_image) { + // process preffix prompt + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + if (llama_decode(ctx, batch_view)) { + LOG_TEE("%s : failed to eval\n", __func__); + return false; + } + } + + printf("\nEvaluated preffix prompt: %i\n", slot.n_past); + + // process image + for (int i = 0; i < slot.image_tokens; i += n_batch) { + int n_eval = slot.image_tokens - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + llama_batch batch = {int32_t(n_eval), nullptr, (slot.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; + if (llama_decode(ctx, batch)) { + LOG_TEE("%s : failed to eval image\n", __func__); + return false; + } + slot.n_past += n_eval; + } + printf("Evaluated image embedding: %i\n", slot.n_past); + + // process suffix prompt + batch.n_tokens = 0; + std::vector suffix_tokens = tokenize(slot.params.input_suffix, true); + for (int i = 0; i < suffix_tokens.size(); ++i) { + printf(llama_token_to_piece(ctx, suffix_tokens[i]).c_str()); + batch.token [batch.n_tokens] = suffix_tokens[i]; + batch.pos [batch.n_tokens] = slot.n_past; + batch.seq_id[batch.n_tokens] = slot.id; + batch.logits[batch.n_tokens] = false; + slot.n_past += 1; + batch.n_tokens += 1; + } + printf("\nEvaluated suffix prompt: %i\n", slot.n_past); + } +#else for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) { batch.token [batch.n_tokens] = prompt_tokens[slot.n_past]; batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system; @@ -855,7 +1074,7 @@ struct llama_server_context batch.logits[batch.n_tokens] = false; batch.n_tokens += 1; } - +#endif // extract the logits only for the last token if (batch.n_tokens > 0) { batch.logits[batch.n_tokens - 1] = true; @@ -872,40 +1091,6 @@ struct llama_server_context return true; } - // context shift takes effect only when there is a single slot - if(slots.size() == 1) { - llama_client_slot slot = slots[0]; - if (slot.cache_tokens.size() >= (size_t)n_ctx) - { - // Shift context - const int n_left = slot.n_past - params.n_keep - 1; - const int n_discard = n_left / 2; - - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard); - - for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) - { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } - - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - - slot.n_past -= n_discard; - - slot.truncated = true; - - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); - } - } - - // process in chunks of params.n_batch - int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); llama_batch batch_view = { @@ -1045,6 +1230,11 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); + printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); + printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); +#ifdef SERVER_MULTIMODAL_SUPPORT + printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); +#endif printf("\n"); } @@ -1315,6 +1505,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_predict = std::stoi(argv[i]); } +#ifdef SERVER_MULTIMODAL_SUPPORT + else if(arg == "--mmproj") { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.mmproj = argv[i]; + } +#endif else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -1531,8 +1731,42 @@ static void parse_options_completion(const json &body, llama_client_slot* slot, } } } - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); +#ifdef SERVER_MULTIMODAL_SUPPORT + if(!llama.multimodal) { + return; + } + + std::string data_b64 = json_value(body, "image_data", std::string("")); + if(!data_b64.empty()) { + if(!slot->prompt.is_array()) { + std::string prompt = slot->prompt.get(); + int pos = prompt.find("[(img)]"); + if(pos == std::string::npos) { + LOG_TEE("Missing image position in prompt\n"); + return; + } else { + // reuse infill prompt input + slot->params.input_prefix = prompt.substr(0, pos); + slot->params.input_suffix = prompt.substr(pos + 7); // ignore [(img)] + slot->params.cache_prompt = false; // multimodal doesn't support cache prompt + } + } + int width, height, channels; + std::vector image_buffer = base64_decode(data_b64); + data_b64.clear(); + // decode base64 + auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3); + slot->img_data.nx = width; + slot->img_data.ny = height; + slot->img_data.size = width * height * 3; + slot->img_data.data = new uint8_t[slot->img_data.size](); + memcpy(slot->img_data.data, data, slot->img_data.size); + stbi_image_free(data); + LOG_TEE("slot %i - RGB image loaded (%i x %i)\n", slot->id, width, height); + slot->request_encode_image = true; + } +#endif } static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)