diff --git a/examples/minicpmv/clip.cpp b/examples/minicpmv/clip.cpp index 6a726fa50..ec5f61554 100644 --- a/examples/minicpmv/clip.cpp +++ b/examples/minicpmv/clip.cpp @@ -549,92 +549,6 @@ struct clip_ctx { ggml_gallocr_t compute_alloc = NULL; }; -std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector>& pos) { - assert(embed_dim % 2 == 0); - int H = pos.size(); - int W = pos[0].size(); - - std::vector omega(embed_dim / 2); - for (int i = 0; i < embed_dim / 2; ++i) { - omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); - } - - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - float out_value = pos[h][w] * omega[d]; - emb[h][w][d] = sin(out_value); - emb[h][w][d + embed_dim / 2] = cos(out_value); - } - } - } - - return emb; -} - -std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>>& grid) { - assert(embed_dim % 2 == 0); - std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) - std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) - - int H = emb_h.size(); - int W = emb_h[0].size(); - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - emb[h][w][d] = emb_h[h][w][d]; - emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; - } - } - } - return emb; -} - -std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { - int grid_h_size = image_size.first; - int grid_w_size = image_size.second; - - std::vector grid_h(grid_h_size); - std::vector grid_w(grid_w_size); - - for (int i = 0; i < grid_h_size; ++i) { - grid_h[i] = static_cast(i); - } - for (int i = 0; i < grid_w_size; ++i) { - grid_w[i] = static_cast(i); - } - - std::vector> grid(grid_h_size, std::vector(grid_w_size)); - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid[h][w] = grid_w[w]; - } - } - std::vector>> grid_2d = {grid, grid}; - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid_2d[0][h][w] = grid_h[h]; - grid_2d[1][h][w] = grid_w[w]; - } - } - - std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); - - int H = image_size.first; - int W = image_size.second; - std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; - } - } - - return pos_embed_2d; -} - static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, std::pair load_image_size = {448, 448}) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); @@ -1536,404 +1450,19 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length return true; } -// Linear interpolation between two points -inline float lerp(float s, float e, float t) { - return s + (e - s) * t; -} -// Bilinear resize function -static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); - - float x_ratio = static_cast(src.nx - 1) / target_width; - float y_ratio = static_cast(src.ny - 1) / target_height; - - for (int y = 0; y < target_height; y++) { - for (int x = 0; x < target_width; x++) { - float px = x_ratio * x; - float py = y_ratio * y; - int x_floor = static_cast(px); - int y_floor = static_cast(py); - float x_lerp = px - x_floor; - float y_lerp = py - y_floor; - - for (int c = 0; c < 3; c++) { - float top = lerp( - static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]), - static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - float bottom = lerp( - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - dst.buf[3 * (y * target_width + x) + c] = static_cast(lerp(top, bottom, y_lerp)); - } - } - } -} - -// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not -static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { +static void normalize_image_u8_to_f32(struct clip_ctx * ctx, const clip_image_u8* src, clip_image_f32* dst) { dst->nx = src->nx; dst->ny = src->ny; dst->buf.resize(src->buf.size()); + const auto & m3 = ctx->image_mean; + const auto & s3 = ctx->image_std; for (size_t i = 0; i < src->buf.size(); ++i) { int c = i % 3; // rgb - dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - mean[c]) / std[c]; + dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - m3[c]) / s3[c]; } } -inline float clip(float x, float lower, float upper) { - return std::max(lower, std::min(x, upper)); -} - -static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { - const int nx = img.nx; - const int ny = img.ny; - - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); - - float Cc; - float C[5]; - float d0, d2, d3, a0, a1, a2, a3; - int i, j, k, jj; - int x, y; - float dx, dy; - float tx, ty; - - tx = (float)nx / (float)target_width; - ty = (float)ny / (float)target_height; - - // Bicubic interpolation; adapted from ViT.cpp, inspired from : - // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 - // -> https://en.wikipedia.org/wiki/Bicubic_interpolation - - for (i = 0; i < target_height; i++) { - for (j = 0; j < target_width; j++) { - x = (int)(tx * j); - y = (int)(ty * i); - - dx = tx * j - x; - dy = ty * i - y; - - for (k = 0; k < 3; k++) { - for (jj = 0; jj <= 3; jj++) { - d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - - C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; - - d0 = C[0] - C[1]; - d2 = C[2] - C[1]; - d3 = C[3] - C[1]; - a0 = C[1]; - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; - - const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); - dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); - } - } - } - } - - return true; -} - -// llava-1.6 type of resize_and_pad (black) -static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { - int target_width = target_resolution.first; - int target_height = target_resolution.second; - - float scale_w = static_cast(target_width) / image.nx; - float scale_h = static_cast(target_height) / image.ny; - - int new_width, new_height; - - if (scale_w < scale_h) { - new_width = target_width; - new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); - } else { - new_height = target_height; - new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); - } - - clip_image_u8 resized_image; - // bilinear_resize(image, resized_image, new_width, new_height); - bicubic_resize(image, resized_image, new_width, new_height); - - clip_image_u8 padded_image; - padded_image.nx = target_width; - padded_image.ny = target_height; - padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black - - // Calculate padding offsets - int pad_x = (target_width - new_width) / 2; - int pad_y = (target_height - new_height) / 2; - - // Copy the resized image into the center of the padded buffer - for (int y = 0; y < new_height; ++y) { - for (int x = 0; x < new_width; ++x) { - for (int c = 0; c < 3; ++c) { - padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; - } - } - } - image_output = std::move(padded_image); -} - -/** - * Selects the best resolution from a list of possible resolutions based on the original size. - * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). - */ -static std::pair select_best_resolution(const std::pair & original_size, const std::vector> & possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; - } - } - - return best_fit; -} - -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; - int width = image.nx; - int height = image.ny; - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - clip_image_u8 *patch = clip_image_u8_init(); - patch->nx = std::min(patch_size, width - j); - patch->ny = std::min(patch_size, height - i); - patch->buf.resize(3 * patch->nx * patch->ny); - for (int y = 0; y < patch->ny; ++y) { - for (int x = 0; x < patch->nx; ++x) { - for (int c = 0; c < 3; ++c) { - patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; - } - } - } - patches.push_back(patch); - } - } - return patches; -} - -// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector -// res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { - bool pad_to_square = true; - if (!ctx->has_vision_encoder) { - LOG_TEE("This gguf file seems to have no vision encoder\n"); - return false; - } - auto & params = ctx->vision_model.hparams; - // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) { - pad_to_square = false; - } - // free the previous res_imgs if any set - if (res_imgs->size > 0) { - clip_image_f32_batch_free(res_imgs); - } - res_imgs->data = nullptr; - res_imgs->size = 0; - - // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) - // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - - clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily - temp->nx = img->nx; - temp->ny = img->ny; - temp->buf.resize(img->buf.size()); - memcpy(temp->buf.data(), img->buf.data(), temp->buf.size()); - - // if (pad_to_square && img->nx != img->ny) { - // int longer_side = std::max(img->nx, img->ny); - // temp->nx = img->nx; - // temp->ny = longer_side; - // temp->buf.resize(3 * longer_side * longer_side); - // const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) - - // // fill with background color - // for (size_t i = 0; i < temp->buf.size(); i++) { - // temp->buf[i] = bc[i % 3]; - // } - - // // copy from the input image - // for (int y = 0; y < img->ny; y++) { - // for (int x = 0; x < img->nx; x++) { - // const int i = 3 * (y * img->nx + x); - // const int j = 3 * (y * temp->nx + x); - // temp->buf[j] = img->buf[i]; - // temp->buf[j+1] = img->buf[i+1]; - // temp->buf[j+2] = img->buf[i+2]; - // } - // } - // } else { - // if (params.image_grid_pinpoints[0] != 0) { - // // "spatial_unpad" with "anyres" processing for llava-1.6 - // std::vector> possible_resolutions; - // for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { - // possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - // } - // std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); - // // clip_image_save_to_bmp(*img, "input.bmp"); - // resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 - // // clip_image_save_to_bmp(*temp, "resized.bmp"); - // // visually verify normalized image: - // // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); - // // { - // // clip_image_u8 * temp2 = clip_image_u8_init(); - // // clip_image_convert_f32_to_u8(*res, *temp2); - // // clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp"); - // // clip_image_u8_free(temp2); - // // } - - // std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - - // clip_image_u8 *image_original_resize = clip_image_u8_init(); - // // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - // bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - // patches.insert(patches.begin(), image_original_resize); - // // clip_image_f32_batch_init(patches.size()); - // res_imgs->size = patches.size(); - // res_imgs->data = new clip_image_f32[res_imgs->size]; - // int num=0; - // for (auto& patch : patches) { - // normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std); - // num++; - // } - - // for (size_t i = 0; i < patches.size(); i++) { - // // LOG_TEE("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); - // clip_image_u8_free(patches[i]); - // } - - // clip_image_u8_free(temp); - - // return true; - // } else { - // temp->nx = img->nx; - // temp->ny = img->ny; - // temp->buf.resize(img->buf.size()); - // memcpy(temp->buf.data(), img->buf.data(), temp->buf.size()); - // } - // } - - const int nx = temp->nx; - const int ny = temp->ny; - // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); - - const int nx2 = temp->nx; - const int ny2 = temp->ny; - - clip_image_f32 * res = clip_image_f32_init(); - res->nx = nx2; - res->ny = ny2; - res->buf.resize(3 * nx2 * ny2); - - // const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size; - - // const int nx3 = int(nx / scale + 0.5f); - // const int ny3 = int(ny / scale + 0.5f); - - const int nx3 = nx; - const int ny3 = ny; - - const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; - const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; - - for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x++) { - for (int c = 0; c < 3; c++) { - // linear interpolation - const float sx = x; - const float sy = y; - - const int x0 = std::max(0, (int)std::floor(sx)); - const int y0 = std::max(0, (int)std::floor(sy)); - - const int x1 = std::min(x0 + 1, nx - 1); - const int y1 = std::min(y0 + 1, ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = 3 * (y0 * nx + x0) + c; - const int j01 = 3 * (y0 * nx + x1) + c; - const int j10 = 3 * (y1 * nx + x0) + c; - const int j11 = 3 * (y1 * nx + x1) + c; - - const float v00 = temp->buf[j00]; - const float v01 = temp->buf[j01]; - const float v10 = temp->buf[j10]; - const float v11 = temp->buf[j11]; - - const float v0 = v00 * (1.0f - dx) + v01 * dx; - const float v1 = v10 * (1.0f - dx) + v11 * dx; - - const float v = v0 * (1.0f - dy) + v1 * dy; - - const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); - - const int i = 3 * (y * nx3 + x) + c; - - res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; - } - } - } - clip_image_u8_free(temp); - - // { - // clip_image_u8 * temp2 = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*res, *temp2); - // clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp"); - // clip_image_u8_free(temp2); - // } - // res_imgs.push_back(res); - - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; - res_imgs->data[0] = *res; - clip_image_f32_free(res); - - return true; -} - ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { return ctx->vision_model.image_newline; } @@ -1986,6 +1515,92 @@ int clip_n_patches(const struct clip_ctx * ctx) { return n_patches; } +std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector>& pos) { + assert(embed_dim % 2 == 0); + int H = pos.size(); + int W = pos[0].size(); + + std::vector omega(embed_dim / 2); + for (int i = 0; i < embed_dim / 2; ++i) { + omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); + } + + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + float out_value = pos[h][w] * omega[d]; + emb[h][w][d] = sin(out_value); + emb[h][w][d + embed_dim / 2] = cos(out_value); + } + } + } + + return emb; +} + +std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>>& grid) { + assert(embed_dim % 2 == 0); + std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) + std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) + + int H = emb_h.size(); + int W = emb_h[0].size(); + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + emb[h][w][d] = emb_h[h][w][d]; + emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; + } + } + } + return emb; +} + +std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { + int grid_h_size = image_size.first; + int grid_w_size = image_size.second; + + std::vector grid_h(grid_h_size); + std::vector grid_w(grid_w_size); + + for (int i = 0; i < grid_h_size; ++i) { + grid_h[i] = static_cast(i); + } + for (int i = 0; i < grid_w_size; ++i) { + grid_w[i] = static_cast(i); + } + + std::vector> grid(grid_h_size, std::vector(grid_w_size)); + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid[h][w] = grid_w[w]; + } + } + std::vector>> grid_2d = {grid, grid}; + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid_2d[0][h][w] = grid_h[h]; + grid_2d[1][h][w] = grid_w[w]; + } + } + + std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); + + int H = image_size.first; + int W = image_size.second; + std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; + } + } + + return pos_embed_2d; +} + bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, std::pair load_image_size = {448, 448}) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); @@ -2052,12 +1667,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int n = 0; - float t = 0; for (int i = 0; i < num_positions; i++) { - positions_data[i] = n; - t=70.0*i/num_positions-1; - if(t>n)n++; + positions_data[i] = std::floor(70.0*i/num_positions); } ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); diff --git a/examples/minicpmv/clip.h b/examples/minicpmv/clip.h index 335a02e1d..aae4c7c3a 100644 --- a/examples/minicpmv/clip.h +++ b/examples/minicpmv/clip.h @@ -69,8 +69,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); -/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ -CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); +static void normalize_image_u8_to_f32(struct clip_ctx * ctx, const clip_image_u8* src, clip_image_f32* dst); CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); diff --git a/examples/minicpmv/minicpmv.cpp b/examples/minicpmv/minicpmv.cpp index 7339a02b4..1b4273e27 100644 --- a/examples/minicpmv/minicpmv.cpp +++ b/examples/minicpmv/minicpmv.cpp @@ -33,45 +33,28 @@ struct clip_image_grid_shape { static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch img_res_v; - img_res_v.size = 0; - img_res_v.data = nullptr; + + clip_image_f32 * img_res_v = clip_image_f32_init(); std::pair load_image_size; load_image_size.first = img->nx; load_image_size.second = img->ny; - const int64_t t_img_enc_start_us_ip = ggml_time_us(); - if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) { - LOG_TEE("%s: unable to preprocess image\n", __func__); - delete[] img_res_v.data; - return false; - } - - const int64_t t_img_enc_end_us_ip = ggml_time_us(); - float t_img_enc_ms_ip = (t_img_enc_end_us_ip - t_img_enc_start_us_ip) / 1000.0; - - LOG_TEE("\n%s: image encoded in %8.2f ms by clip_image_preprocess.\n", __func__, t_img_enc_ms_ip); + normalize_image_u8_to_f32(ctx_clip, img, img_res_v); const int64_t t_img_enc_start_us = ggml_time_us(); const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - LOG_TEE("\n%s: mm_patch_merge_type is %s.\n", __func__, mm_patch_merge_type); *n_img_pos = clip_n_patches(ctx_clip); - - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd, load_image_size); // image_embd shape is 576 x 4096 - delete[] img_res_v.data; + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res_v, image_embd, load_image_size); // image_embd shape is 576 x 4096 if (!encoded) { LOG_TEE("Unable to encode image\n"); - return false; } - LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); const int64_t t_img_enc_end_us = ggml_time_us(); float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - LOG_TEE("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); return true; @@ -231,7 +214,7 @@ static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int tar return true; } -std::vector> slice_image(const clip_image_u8 * img, const int max_slice_nums, const int scale_resolution, const int patch_size, const bool never_split) { +std::vector> slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14, const bool never_split=false) { const std::pair original_size={img->nx,img->ny}; const int original_width = img->nx; const int original_height = img->ny; @@ -244,10 +227,6 @@ std::vector> slice_image(const clip_image_u8 * img, images.push_back(std::vector()); if(multiple <= 1){ - // auto best_resolution = select_best_resolution(image_size, grid_pinpoints); - // clip_image_u8 *image_original_resize = clip_image_u8_init(); - // bicubic_resize(*img, *image_original_resize, best_resolution.first, best_resolution.second); - auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true); clip_image_u8 *source_image = clip_image_u8_init(); bicubic_resize(*img, *source_image, best_size.first, best_size.second); @@ -324,10 +303,7 @@ std::vector> slice_image(const clip_image_u8 * img, images[images.size()-1].push_back(patch); } } - - } - return images; } diff --git a/examples/minicpmv/minicpmv.h b/examples/minicpmv/minicpmv.h index b7f8a24d2..2c2d44fda 100644 --- a/examples/minicpmv/minicpmv.h +++ b/examples/minicpmv/minicpmv.h @@ -31,13 +31,12 @@ struct llava_image_embed { /** sanity check for clip <-> llava embed size match */ MINICPMV_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip); -MINICPMV_API bool llava_image_embed_make_with_clip_img_ollama(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); MINICPMV_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); /** build an image embed from image file bytes */ -MINICPMV_API std::vector> slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14, const bool never_split=false); MINICPMV_API std::vector> llava_image_embed_make_with_bytes_slice(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); /** build an image embed from a path to an image filename */ +MINICPMV_API bool llava_image_embed_make_with_clip_img_ollama(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); MINICPMV_API std::vector> llava_image_embed_make_with_filename_slice(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); MINICPMV_API void llava_image_embed_free_slice(std::vector> embed); /** free an embedding made with llava_image_embed_make_* */