llava : support MiniCPM-V-2.6 (#8967)

* init

* rename

* add run android for termux in readme

* add android readme

* add instructions in readme

* change name in readme

* Update README.md

* fixed line

* add result in readme

* random pos_embed

* add positions index

* change for ollama

* change for ollama

* better pos_embed in clip

* support ollama

* updata cmakelist

* updata cmakelist

* rename wrapper

* clear code

* replace and organize code

* add link

* sync master

* fix warnings

* fix warnings

* fix bug in bicubic resize when need resize iamge smaller

* receive review comments and modify

* receive review comments and modify

* put all code into llava dir

* fix quality problem in pr code

* change n_layer

* add space in "-1"

* imitate reshape bug of python code

* fix bug in clip

* fix issues for merging

* fix llama-minicpmv-cli in cmake file

* change pr readme

* fix code review

* remove in line 33 directory in the /cmakelists.txt (not in example, in the main dir

* fix cmakefile

* add warn

* fix KEY_HAS_MINICPMV_PROJ

* remove load_image_size into clip_ctx

* remove the extern "C", MINICPMV_API

* fix uhd code for review comment

* delete minicpmv-wrapper in pr

* remove uhd_image_embed

* Modify 2 notes

* support minicpmv2.6

* modify convert script of minicpmv

* modify convert

* modify convert

* add readme

* add resampler of v2.6

* modify clip

* modify readme

* fix type-check

* fix type-check

* fix type-check

* fix type-check

* modify convert script and readme

* fix convert script and readme

* fix convert

* fix num in convert

* fix type-check

---------

Co-authored-by: Hongji Zhu <fireyoucan@gmail.com>
Co-authored-by: harvestingmoon <leewenyeong@gmail.com>
This commit is contained in:
tc-mb 2024-08-16 21:34:41 +08:00 committed by GitHub
parent ee2984bdaf
commit d565bb2fd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 645 additions and 35 deletions

View File

@ -16,8 +16,8 @@ Convert PyTorch model to gguf files (You can also download the converted [gguf](
```bash
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2
python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
# quantize int4 version
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M

View File

@ -0,0 +1,107 @@
## MiniCPM-V 2.6
### Prepare models and code
Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder.
Clone llama.cpp:
```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpmv-main
```
### Usage of MiniCPM-V 2.6
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us)
```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3
python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
# quantize int4 version
./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Build for Linux or Mac
```bash
make
make llama-minicpmv-cli
```
Inference on Linux or Mac
```
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```
### Video
Install FFmpeg
```
brew install ffmpeg
brew install pkg-config
```
### Android
#### Build on Android device using Termux
We found that build on Android device would bring better runtime performance, so we recommend to build on device.
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
Install tools in Termux:
```
apt update && apt upgrade -y
apt install git make cmake
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
#### Building the Project using Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
```bash
mkdir build-android
cd build-android
export NDK=/your_ndk_path
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
make
```
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
```
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
$cd /data/data/com.termux/files/home/bin
$chmod +x ./*
```
Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/`
```
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
```
Now, you can start chatting:
```
$cd /data/data/com.termux/files/home/bin
$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
```

View File

@ -81,6 +81,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
@ -526,6 +527,7 @@ struct clip_ctx {
bool has_vision_encoder = false;
bool has_llava_projector = false;
bool has_minicpmv_projector = false;
int minicpmv_version = 2;
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
@ -641,7 +643,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
if (ctx->has_minicpmv_projector) {
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
if (ctx->minicpmv_version == 2) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 3) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
ggml_set_name(pos_embed, "pos_embed");
ggml_set_input(pos_embed);
}
@ -768,8 +775,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
}
else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@ -949,10 +956,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
{ // attention
const int hidden_size = 4096;
int hidden_size = 4096;
const int d_head = 128;
const int n_head = hidden_size/d_head;
const int num_query = 96;
int n_head = hidden_size/d_head;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
num_query = 96;
}
else if (ctx->minicpmv_version == 3) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@ -1149,6 +1166,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
}
idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
if (idx != -1) {
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
}
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
GGML_ASSERT(new_clip->has_vision_encoder);
@ -1910,10 +1932,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
if (clip_is_minicpmv(ctx)) {
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
if(clip_is_minicpmv(ctx)){
int max_slice_nums = 9;
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
res_imgs->size = 0;
for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t i = 0; i < imgs.size(); ++i){
res_imgs->size += imgs[i].size();
}
res_imgs->data = new clip_image_f32[res_imgs->size];
@ -2146,7 +2170,12 @@ int clip_n_patches(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
n_patches = 96;
if (ctx->minicpmv_version == 2) {
n_patches = 96;
}
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
}
return n_patches;
@ -2282,6 +2311,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
if(ctx->load_image_size==nullptr){
ctx->load_image_size= clip_image_size_init();
}
const int pos_w = ctx->load_image_size->width/patch_size;
const int pos_h = ctx->load_image_size->height/patch_size;
{
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
@ -2316,8 +2350,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
for (int i = 0; i < num_positions; i++) {
positions_data[i] = std::floor(70.0*i/num_positions);
int bucket_coords_h[70];
int bucket_coords_w[70];
for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
}
for (int i = 0; i < pos_w; i++){
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
}
for (int i = 0, id = 0; i < pos_h; i++){
for (int j = 0; j < pos_w; j++){
positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
@ -2328,12 +2372,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
if(ctx->load_image_size==nullptr){
ctx->load_image_size= clip_image_size_init();
}
int pos_w = ctx->load_image_size->width/patch_size;
int pos_h = ctx->load_image_size->height/patch_size;
int embed_dim = 4096;
if (ctx->minicpmv_version == 2) {
embed_dim = 4096;
}
else if (ctx->minicpmv_version == 3) {
embed_dim = 3584;
}
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@ -2346,7 +2391,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
free(pos_embed_data);
}
} else {
}
else{
{
if (ctx->has_class_embedding) {
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
@ -2548,13 +2594,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_3_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
return 4096;
if (ctx->minicpmv_version == 2) {
return 4096;
}
else if (ctx->minicpmv_version == 3) {
return 3584;
}
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}
bool clip_is_minicpmv(const struct clip_ctx * ctx) {
return ctx->has_minicpmv_projector;
int clip_is_minicpmv(const struct clip_ctx * ctx) {
if (ctx->has_minicpmv_projector) {
return ctx->minicpmv_version;
}
return 0;
}

View File

@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
#ifdef __cplusplus
}

View File

@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
load_image_size->width = img_res_v.data[i].nx;
load_image_size->height = img_res_v.data[i].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
bool encoded = false;
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
if (!encoded) {
LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
return false;

View File

@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
std::string system_prompt;
int idx = 0;
int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
if (has_minicpmv_projector == 2) {
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
}
else if (has_minicpmv_projector == 3) {
system_prompt = "<|im_start|>user\n";
}
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
std::string user_prompt = prompt;
if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
if (!is_first) {
if (has_minicpmv_projector == 2) {
user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
}
else if (has_minicpmv_projector == 3) {
user_prompt = "<|im_start|>user\n" + prompt;
}
}
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
if (has_minicpmv_projector == 2) {
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
}
else if (has_minicpmv_projector == 3) {
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
}
// generate the response
LOG_TEE("\n");

View File

@ -1,9 +1,416 @@
import argparse
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Siglip model. """
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
import os
import math
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import (
logging,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/siglip-base-patch16-224",
# See all SigLIP models at https://huggingface.co/models?filter=siglip
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
if tensor.dtype in [torch.float16, torch.bfloat16]:
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
og_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
tensor.erfinv_()
tensor = tensor.to(og_dtype)
else:
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
if tensor.dtype == torch.float16:
# The `clamp_` op is not (yet?) defined in float16+cpu
tensor = tensor.to(torch.float32)
tensor.clamp_(min=a, max=b)
tensor = tensor.to(torch.float16)
else:
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
):
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsquently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
denom = fan_in
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.self_attn = (
SiglipAttention(config)
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipVisionConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = self.config.hidden_size
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.normal_(module.q_proj.weight)
nn.init.normal_(module.k_proj.weight)
nn.init.normal_(module.v_proj.weight)
nn.init.normal_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.normal_(module.fc1.weight)
nn.init.normal_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
_supports_flash_attn_2 = True
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embedding
import argparse
import json
import re
import torch
import numpy as np
from gguf import *
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
@ -94,6 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
# with proper
args = ap.parse_args()
@ -135,6 +543,15 @@ if args.use_f32:
# model = CLIPModel.from_pretrained(dir_model)
# processor = CLIPProcessor.from_pretrained(dir_model)
minicpmv_version = args.minicpmv_version
emb_dim = 4096
if minicpmv_version == 1:
emb_dim = 2304
elif minicpmv_version == 2:
emb_dim = 4096
elif minicpmv_version == 3:
emb_dim = 3584
default_vision_config = {
"hidden_size": 1152,
"image_size": 980,
@ -144,8 +561,12 @@ default_vision_config = {
"num_hidden_layers": 27,
"patch_size": 14,
}
vision_config = Idefics2VisionConfig(**default_vision_config)
model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3:
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:
@ -158,6 +579,7 @@ fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_minicpmv_projector = False
if args.text_only:
fname_middle = "text-"
has_vision_encoder = False
@ -165,6 +587,7 @@ elif args.minicpmv_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_minicpmv_projector = True
minicpmv_version = 3
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
@ -189,6 +612,7 @@ elif has_minicpmv_projector:
fout.add_description("image encoder for MiniCPM-V")
# add projector type
fout.add_string("clip.projector_type", "resampler")
fout.add_int32("clip.minicpmv_version", minicpmv_version)
else:
fout.add_description("two-tower CLIP model")
@ -274,11 +698,11 @@ def _replace_name_resampler(s, v):
if re.match("resampler.pos_embed", s):
return {
s: v,
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
}
if re.match("resampler.proj", s):
return {
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
}
if re.match("resampler.attn.in_proj_.*", s):

View File

@ -4,7 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model")
ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
args = ap.parse_args()
# find the model part that includes the the multimodal projector weights
@ -29,7 +29,6 @@ if len(clip_tensors) > 0:
f.write("{}\n")
config = model.llm.config
config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5"
config.auto_map = {
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
"AutoModel": "modeling_minicpm.MiniCPMModel",
@ -40,7 +39,6 @@ config.auto_map = {
model.llm.save_pretrained(f"{args.model}/model")
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
tok.save_pretrained(f"{args.model}/model")
# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py")
print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")