mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
clip : enable gpu backend (#4205)
* clip: enable CUDA backend * add missing kernels * add enough padding for alignment * remove ggml_repeat of clip.cpp * add metal backend * llava : fixes - avoid ggml_repeat - use GGML_USE_ instead of CLIP_USE_ macros - remove unused vars --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
91bb39cec7
commit
ce18d727a4
@ -24,7 +24,8 @@ endif()
|
|||||||
|
|
||||||
if (NOT MSVC)
|
if (NOT MSVC)
|
||||||
target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
|
target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(TARGET BUILD_INFO)
|
if(TARGET BUILD_INFO)
|
||||||
add_dependencies(llava BUILD_INFO)
|
add_dependencies(llava BUILD_INFO)
|
||||||
endif()
|
endif()
|
||||||
|
@ -16,12 +16,19 @@
|
|||||||
#include "clip.h"
|
#include "clip.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "ggml-alloc.h"
|
#include "ggml-alloc.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_METAL
|
||||||
|
#include "ggml-metal.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#define STB_IMAGE_IMPLEMENTATION
|
#define STB_IMAGE_IMPLEMENTATION
|
||||||
#include "stb_image.h"
|
#include "stb_image.h"
|
||||||
|
|
||||||
#define CLIP_DEBUG
|
|
||||||
|
|
||||||
static std::string format(const char * fmt, ...) {
|
static std::string format(const char * fmt, ...) {
|
||||||
va_list ap;
|
va_list ap;
|
||||||
va_list ap2;
|
va_list ap2;
|
||||||
@ -196,20 +203,6 @@ struct clip_vision_model {
|
|||||||
struct ggml_tensor * mm_2_b;
|
struct ggml_tensor * mm_2_b;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
|
|
||||||
struct clip_buffer {
|
|
||||||
uint8_t * data = NULL;
|
|
||||||
size_t size = 0;
|
|
||||||
|
|
||||||
void resize(size_t size) {
|
|
||||||
delete[] data;
|
|
||||||
data = new uint8_t[size];
|
|
||||||
this->size = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
~clip_buffer() { delete[] data; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct clip_ctx {
|
struct clip_ctx {
|
||||||
bool has_text_encoder = false;
|
bool has_text_encoder = false;
|
||||||
bool has_vision_encoder = false;
|
bool has_vision_encoder = false;
|
||||||
@ -223,9 +216,10 @@ struct clip_ctx {
|
|||||||
struct gguf_context * ctx_gguf;
|
struct gguf_context * ctx_gguf;
|
||||||
|
|
||||||
// memory buffers to evaluate the model
|
// memory buffers to evaluate the model
|
||||||
clip_buffer buf_compute;
|
ggml_backend_buffer_t params_buffer = NULL;
|
||||||
clip_buffer buf_alloc;
|
ggml_backend_buffer_t compute_buffer = NULL;
|
||||||
ggml_allocr * alloc = NULL;
|
ggml_backend_t backend = NULL;
|
||||||
|
ggml_allocr * compute_alloc = NULL;
|
||||||
};
|
};
|
||||||
|
|
||||||
static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
||||||
@ -252,25 +246,20 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
if(ctx->has_llava_projector) {
|
if(ctx->has_llava_projector) {
|
||||||
GGML_ASSERT(batch_size == 1);
|
GGML_ASSERT(batch_size == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & buf_compute = ctx->buf_compute;
|
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ buf_compute.size,
|
/*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead(),
|
||||||
/*.mem_buffer =*/ buf_compute.data,
|
/*.mem_buffer =*/ NULL,
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
params.no_alloc = true;
|
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
|
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
|
||||||
ggml_allocr_alloc(ctx->alloc, inp_raw);
|
ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
|
||||||
|
|
||||||
if (!ggml_allocr_is_measure(ctx->alloc)) {
|
if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
|
||||||
float * data = (float *)ggml_get_data(inp_raw);
|
float * data = (float *)malloc(ggml_nbytes(inp_raw));
|
||||||
|
|
||||||
for (size_t i = 0; i < imgs->size; i++) {
|
for (size_t i = 0; i < imgs->size; i++) {
|
||||||
const int nx = imgs->data[i].nx;
|
const int nx = imgs->data[i].nx;
|
||||||
@ -289,6 +278,8 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
|
||||||
|
free(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||||
@ -298,36 +289,39 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
|
|
||||||
// concat class_embeddings and patch_embeddings
|
// concat class_embeddings and patch_embeddings
|
||||||
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
||||||
ggml_allocr_alloc(ctx->alloc, embeddings);
|
ggml_allocr_alloc(ctx->compute_alloc, embeddings);
|
||||||
if (!ggml_allocr_is_measure(ctx->alloc)) {
|
if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
|
||||||
ggml_set_zero(embeddings);
|
void* zero_mem = malloc(ggml_nbytes(embeddings));
|
||||||
|
memset(zero_mem, 0, ggml_nbytes(embeddings));
|
||||||
|
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
|
||||||
|
free(zero_mem);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
|
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
||||||
ggml_allocr_alloc(ctx->alloc, temp);
|
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
||||||
|
|
||||||
embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1],
|
embeddings = ggml_acc(ctx0, embeddings, inp,
|
||||||
embeddings->nb[2], embeddings->nb[3], 0);
|
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
||||||
embeddings =
|
|
||||||
ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
|
||||||
|
|
||||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||||
ggml_allocr_alloc(ctx->alloc, positions);
|
ggml_allocr_alloc(ctx->compute_alloc, positions);
|
||||||
if (!ggml_allocr_is_measure(ctx->alloc)) {
|
if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
|
||||||
|
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||||
for (int i = 0; i < num_positions; i++) {
|
for (int i = 0; i < num_positions; i++) {
|
||||||
ggml_set_i32_1d(positions, i, i);
|
positions_data[i] = i;
|
||||||
}
|
}
|
||||||
|
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||||
|
free(positions_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings =
|
embeddings =
|
||||||
ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings));
|
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||||
|
|
||||||
// pre-layernorm
|
// pre-layernorm
|
||||||
{
|
{
|
||||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||||
|
|
||||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.pre_ln_w, embeddings), embeddings),
|
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
|
||||||
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over layers
|
// loop over layers
|
||||||
@ -340,15 +334,15 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
{
|
{
|
||||||
cur = ggml_norm(ctx0, cur, eps);
|
cur = ggml_norm(ctx0, cur, eps);
|
||||||
|
|
||||||
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur),
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
|
||||||
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
|
model.layers[il].ln_1_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||||
|
|
||||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||||
@ -356,14 +350,14 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||||
|
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur), ggml_mul_mat(ctx0, model.layers[il].k_w, cur));
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||||
|
|
||||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||||
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||||
|
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur), ggml_mul_mat(ctx0, model.layers[il].v_w, cur));
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
|
||||||
|
|
||||||
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
|
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
|
||||||
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
||||||
@ -379,7 +373,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
}
|
}
|
||||||
|
|
||||||
// attention output
|
// attention output
|
||||||
cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur), ggml_mul_mat(ctx0, model.layers[il].o_w, cur));
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
|
||||||
|
|
||||||
// re-add the layer input, e.g., residual
|
// re-add the layer input, e.g., residual
|
||||||
cur = ggml_add(ctx0, cur, embeddings);
|
cur = ggml_add(ctx0, cur, embeddings);
|
||||||
@ -390,12 +384,11 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
{
|
{
|
||||||
cur = ggml_norm(ctx0, cur, eps);
|
cur = ggml_norm(ctx0, cur, eps);
|
||||||
|
|
||||||
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur),
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||||
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||||
cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), cur);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||||
|
|
||||||
if (ctx->use_gelu) {
|
if (ctx->use_gelu) {
|
||||||
cur = ggml_gelu_inplace(ctx0, cur);
|
cur = ggml_gelu_inplace(ctx0, cur);
|
||||||
@ -404,7 +397,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||||
cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), cur);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
||||||
|
|
||||||
// residual 2
|
// residual 2
|
||||||
cur = ggml_add(ctx0, embeddings, cur);
|
cur = ggml_add(ctx0, embeddings, cur);
|
||||||
@ -417,23 +410,26 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
|
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
|
||||||
|
|
||||||
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
|
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
|
||||||
ggml_allocr_alloc(ctx->alloc, patches);
|
ggml_allocr_alloc(ctx->compute_alloc, patches);
|
||||||
if (!ggml_allocr_is_measure(ctx->alloc)) {
|
if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
|
||||||
for (int i = 0; i < num_patches; ++i) {
|
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||||
ggml_set_i32_1d(patches, i, i+1);
|
for (int i = 0; i < num_positions; i++) {
|
||||||
|
patches_data[i] = i + 1;
|
||||||
}
|
}
|
||||||
|
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||||
|
free(patches_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings = ggml_get_rows(ctx0, embeddings, patches);
|
embeddings = ggml_get_rows(ctx0, embeddings, patches);
|
||||||
|
|
||||||
// mm projection 0
|
// mm projection 0
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings);
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
|
|
||||||
embeddings = ggml_gelu(ctx0, embeddings);
|
embeddings = ggml_gelu(ctx0, embeddings);
|
||||||
|
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
|
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// build the graph
|
// build the graph
|
||||||
@ -446,7 +442,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
|
|||||||
|
|
||||||
// read and create ggml_context containing the tensors and their data
|
// read and create ggml_context containing the tensors and their data
|
||||||
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||||
|
|
||||||
struct ggml_context * meta = NULL;
|
struct ggml_context * meta = NULL;
|
||||||
|
|
||||||
struct gguf_init_params params = {
|
struct gguf_init_params params = {
|
||||||
@ -479,7 +474,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
printf("%s: ftype: %s\n", __func__, ftype_str.c_str());
|
printf("%s: ftype: %s\n", __func__, ftype_str.c_str());
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||||
// kv
|
// kv
|
||||||
if (verbosity >= 3) {
|
if (verbosity >= 3) {
|
||||||
const int n_kv = gguf_get_n_kv(ctx);
|
const int n_kv = gguf_get_n_kv(ctx);
|
||||||
@ -493,27 +488,38 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// data
|
// data
|
||||||
size_t ctx_size = 0;
|
size_t buffer_size = 0;
|
||||||
{
|
{
|
||||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_tensors; ++i) {
|
for (int i = 0; i < n_tensors; ++i) {
|
||||||
const char * name = gguf_get_tensor_name(ctx, i);
|
const char * name = gguf_get_tensor_name(ctx, i);
|
||||||
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
||||||
|
|
||||||
struct ggml_tensor * cur = ggml_get_tensor(meta, name);
|
struct ggml_tensor * cur = ggml_get_tensor(meta, name);
|
||||||
ctx_size += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
|
|
||||||
size_t tensor_size = ggml_nbytes(cur);
|
size_t tensor_size = ggml_nbytes(cur);
|
||||||
size_t padded_size = ggml_nbytes_pad(cur);
|
buffer_size += tensor_size;
|
||||||
ctx_size += padded_size;
|
|
||||||
if (verbosity >= 3) {
|
if (verbosity >= 3) {
|
||||||
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, padded_size=%zu, offset=%zu\n", __func__, i,
|
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu\n", __func__, i,
|
||||||
ggml_n_dims(cur), cur->name, tensor_size, padded_size, offset);
|
ggml_n_dims(cur), cur->name, tensor_size, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buffer_size += n_tensors * 128 /* CLIP PADDING */;
|
||||||
|
|
||||||
clip_ctx * new_clip = new clip_ctx;
|
clip_ctx * new_clip = new clip_ctx;
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
new_clip->backend = ggml_backend_cuda_init(0);
|
||||||
|
printf("%s: CLIP using CUDA backend\n", __func__);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_METAL
|
||||||
|
new_clip->backend = ggml_backend_metal_init();
|
||||||
|
printf("%s: CLIP using Metal backend\n", __func__);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (!new_clip->backend) {
|
||||||
|
new_clip->backend = ggml_backend_cpu_init();
|
||||||
|
printf("%s: CLIP using CPU backend\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
// model size and capabilities
|
// model size and capabilities
|
||||||
{
|
{
|
||||||
@ -539,17 +545,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
|
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
|
||||||
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||||
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
|
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
|
||||||
printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
|
printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
|
||||||
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
|
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
|
||||||
|
|
||||||
// load tensors
|
// load tensors
|
||||||
{
|
{
|
||||||
|
std::vector<uint8_t> read_buf;
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ ctx_size,
|
/*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
|
||||||
/*.mem_buffer =*/ NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
new_clip->ctx = ggml_init(params);
|
new_clip->ctx = ggml_init(params);
|
||||||
@ -566,13 +575,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
// add tensors to context
|
||||||
for (int i = 0; i < n_tensors; ++i) {
|
for (int i = 0; i < n_tensors; ++i) {
|
||||||
const char * name = gguf_get_tensor_name(ctx, i);
|
const char * name = gguf_get_tensor_name(ctx, i);
|
||||||
struct ggml_tensor * t = ggml_get_tensor(meta, name);
|
struct ggml_tensor * t = ggml_get_tensor(meta, name);
|
||||||
struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx, t);
|
struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx, t);
|
||||||
ggml_set_name(cur, name);
|
ggml_set_name(cur, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// alloc memory and offload data
|
||||||
|
new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
|
||||||
|
ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
|
||||||
|
for (int i = 0; i < n_tensors; ++i) {
|
||||||
|
const char * name = gguf_get_tensor_name(ctx, i);
|
||||||
|
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx, name);
|
||||||
|
ggml_allocr_alloc(alloc, cur);
|
||||||
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
|
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
|
||||||
fin.seekg(offset, std::ios::beg);
|
fin.seekg(offset, std::ios::beg);
|
||||||
if (!fin) {
|
if (!fin) {
|
||||||
@ -580,10 +597,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
clip_free(new_clip);
|
clip_free(new_clip);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
int num_bytes = ggml_nbytes(cur);
|
||||||
fin.read(reinterpret_cast<char *>(cur->data), ggml_nbytes(t));
|
if (ggml_backend_is_cpu(new_clip->backend)
|
||||||
|
#ifdef GGML_USE_METAL
|
||||||
|
|| ggml_backend_is_metal(new_clip->backend)
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
|
// for the CPU and Metal backend, we can read directly into the tensor
|
||||||
|
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
|
||||||
|
} else {
|
||||||
|
// read into a temporary buffer first, then copy to device memory
|
||||||
|
read_buf.resize(num_bytes);
|
||||||
|
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
|
||||||
|
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
ggml_allocr_free(alloc);
|
||||||
fin.close();
|
fin.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -657,18 +686,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
|
|
||||||
// measure mem requirement and allocate
|
// measure mem requirement and allocate
|
||||||
{
|
{
|
||||||
static const size_t tensor_alignment = 32;
|
new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
|
||||||
new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
|
|
||||||
new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
|
|
||||||
clip_image_f32_batch batch;
|
clip_image_f32_batch batch;
|
||||||
batch.size = 1;
|
batch.size = 1;
|
||||||
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
|
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
|
||||||
size_t alloc_size = ggml_allocr_alloc_graph(new_clip->alloc, gf) + tensor_alignment;
|
size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
|
||||||
ggml_allocr_free(new_clip->alloc);
|
ggml_allocr_free(new_clip->compute_alloc);
|
||||||
new_clip->buf_alloc.resize(alloc_size);
|
new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
|
||||||
new_clip->alloc = ggml_allocr_new(new_clip->buf_alloc.data, new_clip->buf_alloc.size, tensor_alignment);
|
new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
|
||||||
|
|
||||||
printf("%s: total allocated memory: %.2f MB\n", __func__, (new_clip->buf_compute.size + alloc_size)/1024.0/1024.0);
|
printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_clip;
|
return new_clip;
|
||||||
@ -852,29 +879,29 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reset alloc buffer to clean the memory from previous invocations
|
// reset alloc buffer to clean the memory from previous invocations
|
||||||
ggml_allocr_reset(ctx->alloc);
|
ggml_allocr_reset(ctx->compute_alloc);
|
||||||
|
|
||||||
// build the inference graph
|
// build the inference graph
|
||||||
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
|
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
|
||||||
ggml_allocr_alloc_graph(ctx->alloc, gf);
|
ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
|
||||||
|
|
||||||
struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
|
if (ggml_backend_is_cpu(ctx->backend)) {
|
||||||
if (plan.work_size > 0) {
|
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
||||||
plan.work_data = (uint8_t *)malloc(plan.work_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_graph_compute(gf, &plan);
|
#ifdef GGML_USE_METAL
|
||||||
|
if (ggml_backend_is_metal(ctx->backend)) {
|
||||||
|
ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ggml_backend_graph_compute(ctx->backend, gf);
|
||||||
|
|
||||||
// the last node is the embedding tensor
|
// the last node is the embedding tensor
|
||||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
|
||||||
|
|
||||||
// copy the embeddings to the location passed by the user
|
// copy the embeddings to the location passed by the user
|
||||||
memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
|
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
|
||||||
|
|
||||||
if (plan.work_size > 0) {
|
|
||||||
free(plan.work_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1045,8 +1072,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
|||||||
gguf_free(ctx_out);
|
gguf_free(ctx_out);
|
||||||
|
|
||||||
{
|
{
|
||||||
printf("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
|
printf("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
|
||||||
printf("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
|
printf("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
|
||||||
|
|
||||||
int64_t sum_all = 0;
|
int64_t sum_all = 0;
|
||||||
for (size_t i = 0; i < hist_all.size(); ++i) {
|
for (size_t i = 0; i < hist_all.size(); ++i) {
|
||||||
|
Loading…
Reference in New Issue
Block a user