support ollama

This commit is contained in:
caitianchi 2024-05-28 01:13:57 +08:00
parent 8541e99629
commit d8974b8ea6
2 changed files with 79 additions and 0 deletions

View File

@ -323,6 +323,84 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
return true;
}
bool llava_image_embed_make_with_clip_img_ollama(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
std::vector<std::vector<clip_image_u8 *>> imgs = slice_image(img);
std::vector<std::vector<llava_image_embed *>> image_embed_slices;
for (size_t i = 0; i < imgs.size(); ++i){
image_embed_slices.push_back(std::vector<llava_image_embed *>());
for (size_t j = 0; j < imgs[i].size(); ++j) {
float* image_embed = NULL;
int n_image_pos = 0;
bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, imgs[i][j], &image_embed, &n_image_pos);
if (!image_embed_result) {
LOG_TEE("%s: coulnd't embed the image\n", __func__);
return false;
}
auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
result->embed = image_embed;
result->n_image_pos = n_image_pos;
image_embed_slices[i].push_back(result);
}
}
std::string fname = "./examples/minicpm-v2.5/slice_token_for_ollama.raw";
auto file = fopen(fname.c_str(), "rb");
if (file == NULL) {
LOG_TEE("%s: can't read file %s\n", __func__, fname.c_str());
return false;
}
fseek(file, 0, SEEK_END);
auto fileSize = ftell(file);
fseek(file, 0, SEEK_SET);
auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
if (buffer == NULL) {
LOG_TEE("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, fname.c_str());
perror("Memory allocation error");
fclose(file);
return false;
}
errno = 0;
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
if (ferror(file)) {
die_fmt("read error: %s", strerror(errno));
}
if (ret != (size_t) fileSize) {
die("unexpectedly reached end of file");
}
fclose(file); // Close the file
float * all_image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*61);
int all_n_img_pos=0;
int token_len = 4096*sizeof(float);
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer, token_len);
std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[0][0]->embed, 96*token_len);
all_n_img_pos+=96;
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer+token_len, token_len);
if (image_embed_slices.size() > 1) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer+token_len*2, token_len);
for (size_t i = 1; i < image_embed_slices.size(); ++i) {
for (size_t j = 0; j < image_embed_slices[i].size(); ++j) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer, token_len);
std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[i][j]->embed, 96*token_len);
all_n_img_pos+=96;
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer+token_len, token_len);
if (j == image_embed_slices[i].size() - 1) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer+token_len*4, token_len);
}
}
}
std::memcpy(all_image_embd+token_len*all_n_img_pos++, buffer+token_len*3, token_len);
}
*image_embd_out = all_image_embd;
*n_img_pos_out = all_n_img_pos;
return true;
}
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6); // TODO: base on gridsize/llava model
if (!image_embd) {

View File

@ -31,6 +31,7 @@ struct llava_image_embed {
/** sanity check for clip <-> llava embed size match */
MINICPMV_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip);
MINICPMV_API bool llava_image_embed_make_with_clip_img_ollama(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);
MINICPMV_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);
/** build an image embed from image file bytes */