llava multimodal integration

FSSRepo 2023-10-13 18:42:44 -04:00
parent eb08201227
commit 9d98cdda2c
2 changed files with 274 additions and 40 deletions

View File

@@ -6,7 +6,7 @@ install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()

View File

@@ -3,6 +3,13 @@
#include "build-info.h"
#include "grammar-parser.h"
// #define SERVER_MULTIMODAL_SUPPORT
#ifdef SERVER_MULTIMODAL_SUPPORT
#include "../llava/clip.h"
#include "stb_image.h"
#endif
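A note on the stb_image include above: stb_image.h is a single-header library, so exactly one translation unit in the build must define STB_IMAGE_IMPLEMENTATION before including it; the llava clip library added to the link line in the CMake change is expected to provide that definition, leaving the server with declarations only. A minimal sketch of what that one translation unit contains (hypothetical standalone file, not part of this patch):

    // in exactly one .cpp file of the build (the llava clip implementation, in this tree)
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"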
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
@@ -61,6 +68,57 @@ static bool server_verbose = false;
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
#ifdef SERVER_MULTIMODAL_SUPPORT
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
static inline bool is_base64(uint8_t c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
std::vector<uint8_t> base64_decode(std::string const& encoded_string) {
int in_len = encoded_string.size();
int i = 0;
int j = 0;
int in_ = 0;
uint8_t char_array_4[4], char_array_3[3];
std::vector<uint8_t> ret;
while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4) {
for (i = 0; i < 4; i++)
char_array_4[i] = base64_chars.find(char_array_4[i]);
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++)
ret.push_back(char_array_3[i]);
i = 0;
}
}
if (i) {
for (j = i; j < 4; j++)
char_array_4[j] = 0;
for (j = 0; j < 4; j++)
char_array_4[j] = base64_chars.find(char_array_4[j]);
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; (j < i - 1); j++) ret.push_back(char_array_3[j]);
}
return ret;
}
#endif
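A minimal, self-contained usage sketch for the decoder above (the encoded string is illustrative; in the server it comes from the request's image_data field):

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const std::string encoded = "SGVsbG8sIGxsYXZhIQ=="; // "Hello, llava!" in base64
        std::vector<uint8_t> bytes = base64_decode(encoded); // helper defined above
        printf("decoded %zu bytes: %.*s\n", bytes.size(), (int) bytes.size(), (const char *) bytes.data());
        return 0;
    }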
// parallel
enum slot_state
@@ -258,6 +316,13 @@ struct llama_client_slot
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
#ifdef SERVER_MULTIMODAL_SUPPORT
clip_image_u8 img_data;
bool request_encode_image = false;
float* image_embedding = nullptr;
int image_tokens = 0;
#endif
void reset() {
num_prompt_tokens = 0;
generated_text = "";
@@ -375,6 +440,12 @@ struct llama_server_context
std::string user_name = ""; // this should be the anti-prompt
std::string assistant_name = ""; // this is used to generate the prompt
#ifdef SERVER_MULTIMODAL_SUPPORT
bool multimodal = false;
clip_ctx *clp_ctx = nullptr;
int n_embd;
#endif
llama_model *model = nullptr;
llama_context *ctx = nullptr;
llama_batch batch;
@@ -407,12 +478,40 @@ struct llama_server_context
bool loadModel(const gpt_params &params_)
{
params = params_;
#ifdef SERVER_MULTIMODAL_SUPPORT
if(!params.mmproj.empty()) {
multimodal = true;
LOG_TEE("Multi Modal Mode Enabled");
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
if(clp_ctx == nullptr) {
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
return false;
}
if(params.n_ctx < 2048) { // request larger context for the image embedding
params.n_ctx = 2048;
}
}
#endif
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr)
{
LOG_ERROR("unable to load model", {{"model", params_.model}});
LOG_ERROR("unable to load model", {{"model", params.model}});
return false;
}
#ifdef SERVER_MULTIMODAL_SUPPORT
if(multimodal) {
int n_img_embd = clip_n_mmproj_embd(clp_ctx);
n_embd = llama_n_embd(model);
if (n_img_embd != n_embd) {
LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_embd);
llama_free(ctx);
llama_free_model(model);
return false;
}
}
#endif
n_ctx = llama_n_ctx(ctx);
n_vocab = llama_n_vocab(model);
candidates.reserve(n_vocab);
@@ -729,6 +828,37 @@ struct llama_server_context
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// context shift takes effect only when there is a single slot
if(slots.size() == 1) {
llama_client_slot & slot = slots[0]; // reference, so the shifted cache/n_past state persists on the real slot
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
{
// Shift context
const int n_left = slot.n_past - params.n_keep - 1;
const int n_discard = n_left / 2;
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
slot.n_past -= n_discard;
slot.truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
}
}
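For intuition, the shift arithmetic above with illustrative numbers (not taken from the patch):

    // hypothetical fully-filled 2048-token context that keeps its first 256 tokens
    int n_past    = 2048, n_keep = 256;
    int n_left    = n_past - n_keep - 1;  // 1791
    int n_discard = n_left / 2;           // 895
    // KV cache positions [n_keep + 1, n_keep + 1 + n_discard) = [257, 1152) are removed,
    // the remaining entries are shifted back by 895 positions, and generation continues
    // with roughly half of the context freed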
// decode any currently ongoing sequences
for (auto & slot : slots) {
// release the slot
@@ -764,6 +894,8 @@ struct llama_server_context
batch.n_tokens += 1;
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) {
@@ -772,9 +904,35 @@
if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
std::vector<llama_token> prompt_tokens;
#ifdef SERVER_MULTIMODAL_SUPPORT
bool ingest_image = false;
if(slot.request_encode_image) {
ingest_image = true;
clip_image_f32 img_res;
if (!clip_image_preprocess(clp_ctx, &slot.img_data, &img_res, /*pad2square =*/ true)) {
LOG_TEE("Error processing the given image");
clip_free(clp_ctx);
return false;
}
slot.image_tokens = clip_n_patches(clp_ctx);
slot.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
if (!slot.image_embedding) {
LOG_TEE("Unable to allocate memory for image embeddings\n");
clip_free(clp_ctx);
return false;
}
LOG_TEE("slot %i - encoding image\n", slot.id);
if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, slot.image_embedding)) {
LOG_TEE("Unable to encode image\n");
return false;
}
slot.request_encode_image = false;
}
#endif
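To give a feel for the buffer sizes involved, a short sketch with typical LLaVA-1.5-7B numbers (ViT-L/14 at 336 px feeding a 4096-dim LLaMA; these figures are assumed here, not stated in the patch):

    int image_tokens  = (336 / 14) * (336 / 14);                         // 576 patches, the count clip_n_patches() would report
    int n_embd        = 4096;                                            // projector output dim, must equal llama_n_embd(model)
    size_t embd_bytes = (size_t) image_tokens * n_embd * sizeof(float);  // ~9.4 MB, the size clip_embd_nbytes() allocates for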
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
std::vector<llama_token> prompt_tokens;
if(slot.infill) {
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
@@ -809,7 +967,6 @@ struct llama_server_context
slot.n_past--;
}
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
if(!slot.params.cache_prompt) {
@@ -848,6 +1005,68 @@
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});
#ifdef SERVER_MULTIMODAL_SUPPORT
std::vector<llama_token> prefix_tokens = ingest_image ? tokenize(slot.params.input_prefix, true) : prompt_tokens;
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
printf("%s", llama_token_to_piece(ctx, prefix_tokens[slot.n_past]).c_str());
batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
}
if(ingest_image) {
// process prefix prompt
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch_view)) {
LOG_TEE("%s : failed to eval\n", __func__);
return false;
}
}
printf("\nEvaluated preffix prompt: %i\n", slot.n_past);
// process image
for (int i = 0; i < slot.image_tokens; i += n_batch) {
int n_eval = slot.image_tokens - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch_img = { int32_t(n_eval), nullptr, (slot.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
if (llama_decode(ctx, batch_img)) {
LOG_TEE("%s : failed to eval image\n", __func__);
return false;
}
slot.n_past += n_eval;
}
printf("Evaluated image embedding: %i\n", slot.n_past);
// process suffix prompt
batch.n_tokens = 0;
std::vector<llama_token> suffix_tokens = tokenize(slot.params.input_suffix, true);
for (size_t i = 0; i < suffix_tokens.size(); ++i) {
printf("%s", llama_token_to_piece(ctx, suffix_tokens[i]).c_str());
batch.token [batch.n_tokens] = suffix_tokens[i];
batch.pos [batch.n_tokens] = slot.n_past;
batch.seq_id[batch.n_tokens] = slot.id;
batch.logits[batch.n_tokens] = false;
slot.n_past += 1;
batch.n_tokens += 1;
}
printf("\nEvaluated suffix prompt: %i\n", slot.n_past);
}
#else
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
@@ -855,7 +1074,7 @@ struct llama_server_context
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
}
#endif
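For reference, the aggregate initializer used for the image chunks above lines up with the llama_batch fields of this API revision as follows (an annotated restatement of the same values, not a change to the patch):

    llama_batch batch_img = {
        int32_t(n_eval),                     // n_tokens: number of embedding rows in this chunk
        nullptr,                             // token: unused, embeddings are fed instead of token ids
        slot.image_embedding + i * n_embd,   // embd: pointer into the encoded CLIP image embedding
        nullptr, nullptr, nullptr,           // pos, seq_id, logits: when null, llama_decode falls back to the all_* fields below
        slot.n_past,                         // all_pos_0: position of the first row
        1,                                   // all_pos_1: position delta between consecutive rows
        0,                                   // all_seq_id: sequence id applied to the whole chunk
    };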
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
@@ -872,40 +1091,6 @@ struct llama_server_context
return true;
}
// context shift takes effect only when there is a single slot
if(slots.size() == 1) {
llama_client_slot slot = slots[0];
if (slot.cache_tokens.size() >= (size_t)n_ctx)
{
// Shift context
const int n_left = slot.n_past - params.n_keep - 1;
const int n_discard = n_left / 2;
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
slot.n_past -= n_discard;
slot.truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
}
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
@@ -1045,6 +1230,11 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
#ifdef SERVER_MULTIMODAL_SUPPORT
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
#endif
printf("\n");
}
@@ -1315,6 +1505,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_predict = std::stoi(argv[i]);
}
#ifdef SERVER_MULTIMODAL_SUPPORT
else if(arg == "--mmproj") {
if (++i >= argc)
{
invalid_param = true;
break;
}
params.mmproj = argv[i];
}
#endif
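With this flag, a multimodal build would be started with something like ./server -m models/llava-v1.5-7b/ggml-model-q4_k.gguf --mmproj models/llava-v1.5-7b/mmproj-model-f16.gguf (illustrative paths; the mmproj file is the CLIP projector that ships alongside LLaVA GGUF conversions).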
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -1531,8 +1731,42 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
}
}
}
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
#ifdef SERVER_MULTIMODAL_SUPPORT
if(!llama.multimodal) {
return;
}
std::string data_b64 = json_value(body, "image_data", std::string(""));
if(!data_b64.empty()) {
if(!slot->prompt.is_array()) {
std::string prompt = slot->prompt.get<std::string>();
size_t pos = prompt.find("[(img)]");
if(pos == std::string::npos) {
LOG_TEE("Missing image position in prompt\n");
return;
} else {
// reuse infill prompt input
slot->params.input_prefix = prompt.substr(0, pos);
slot->params.input_suffix = prompt.substr(pos + 7); // ignore [(img)]
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
}
}
int width, height, channels;
std::vector<uint8_t> image_buffer = base64_decode(data_b64);
data_b64.clear();
// decode the base64 payload into RGB pixels with stb_image
auto data = stbi_load_from_memory(image_buffer.data(), (int) image_buffer.size(), &width, &height, &channels, 3);
if (data == nullptr) {
LOG_TEE("slot %i - failed to decode the image data\n", slot->id);
return;
}
slot->img_data.nx = width;
slot->img_data.ny = height;
slot->img_data.size = width * height * 3;
slot->img_data.data = new uint8_t[slot->img_data.size]();
memcpy(slot->img_data.data, data, slot->img_data.size);
stbi_image_free(data);
LOG_TEE("slot %i - RGB image loaded (%i x %i)\n", slot->id, width, height);
slot->request_encode_image = true;
}
#endif
}
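Putting the request side together: the body this parser expects would look roughly like the sketch below, built with the same nlohmann::json type the server already uses (the prompt text and the truncated base64 payload are illustrative):

    #include "json.hpp"
    using json = nlohmann::json;

    // hypothetical request body for a multimodal-enabled server build:
    // the text before "[(img)]" becomes params.input_prefix, the text after it becomes
    // params.input_suffix, and the decoded image is evaluated between the two
    static json make_llava_request() {
        return json{
            {"prompt",     "USER: [(img)] Describe the image in detail.\nASSISTANT:"},
            {"image_data", "iVBORw0KGgoAAAANSUhEUgAA..."},  // base64-encoded PNG/JPEG, truncated here
            {"n_predict",  128},
        };
    }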
static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)