From 80ea089d771f0c2d97afa8bead80ded412f600d7 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 21 Jun 2024 00:38:22 -0500 Subject: [PATCH] llama : allow pooled embeddings on any model (#7477) * create append_pooling operation; allow to specify attention_type; add last token pooling; update examples * find result_norm/result_embd tensors properly; update output allocation logic * only use embd output for pooling_type NONE * get rid of old causal_attn accessor * take out attention_type; add in llama_set_embeddings * bypass logits when doing non-NONE pooling --- common/common.cpp | 2 + examples/embedding/embedding.cpp | 21 +++-- examples/gritlm/gritlm.cpp | 6 +- examples/retrieval/retrieval.cpp | 13 ++- llama.cpp | 152 ++++++++++++++++++++----------- llama.h | 6 +- 6 files changed, 130 insertions(+), 70 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9c23d001b..64f160af1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -541,6 +541,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } else { invalid_param = true; } return true; } @@ -1869,6 +1870,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "backend" }); options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); + if (llama_supports_mlock()) { options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 244751e00..b4b73c017 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -17,9 +17,10 @@ static std::vector split_lines(const std::string & s) { return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu // try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } - } + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); float * out = output + batch.seq_id[i][0] * n_embd; //TODO: I would also add a parameter here to enable normalization or not. @@ -97,6 +92,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 213515791..2c61c2e1e 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,6 +44,7 @@ static std::vector> encode(llama_context * ctx, const std::ve // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); // run model @@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_token eos_token = llama_token_eos(mdl); llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector inputs = llama_tokenize(mdl, prompt, false, true); @@ -166,8 +169,7 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); - // create new context - set to embedding mode - cparams.embeddings = true; + // create generation context llama_context * ctx = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f70..eb89d16da 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -73,9 +73,10 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -160,6 +161,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); diff --git a/llama.cpp b/llama.cpp index 8818c6928..9ca0b7479 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7649,6 +7649,50 @@ struct llm_build_context { return lctx.inp_s_seq; } + struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { + // find result_norm tensor for input + struct ggml_tensor * inp = nullptr; + for (int i = gf->n_nodes - 1; i >= 0; --i) { + inp = gf->nodes[i]; + if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) { + break; + } else { + inp = nullptr; + } + } + GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor"); + + struct ggml_tensor * cur; + + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + { + struct ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + struct ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } break; + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; + default: + { + GGML_ASSERT(false && "unknown pooling type"); + } break; + } + + cb(cur, "result_embd_pooled", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8629,8 +8673,6 @@ struct llm_build_context { if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); } - struct ggml_tensor * inp_mean = build_inp_mean(); - struct ggml_tensor * inp_cls = build_inp_cls(); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); @@ -8805,28 +8847,6 @@ struct llm_build_context { cur = inpL; cb(cur, "result_embd", -1); - // pooling layer - switch (pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // nop - } break; - case LLAMA_POOLING_TYPE_MEAN: - { - cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_CLS: - { - cur = ggml_get_rows(ctx0, cur, inp_cls); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ASSERT(false && "Invalid pooling type"); - } break; - } - ggml_build_forward_expand(gf, cur); return gf; @@ -11911,6 +11931,11 @@ static struct ggml_cgraph * llama_build_graph( GGML_ASSERT(false); } + // add on pooling layer + if (lctx.cparams.embeddings) { + result = llm.append_pooling(result); + } + llm.free(); return result; @@ -12000,7 +12025,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // (!a || b) is a logical implication (a -> b) // !hparams.causal_attn -> !cparams.causal_attn (hparams.causal_attn || !cparams.causal_attn) && - "causal attention with embedding models is not supported" + "causal attention is not supported by this model" ); if (lctx.inp_KQ_mask) { @@ -12132,6 +12157,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_pos pos = batch.pos[i]; + + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = i; + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; @@ -12193,8 +12249,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { const auto n_embd = hparams.n_embd; // TODO: use a per-batch flag for logits presence instead - const bool has_logits = cparams.causal_attn; - const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + const bool has_logits = !cparams.embeddings; + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; @@ -12324,11 +12380,13 @@ static int llama_decode_internal( std::vector> seq_id; // count outputs - if (batch_all.logits) { + if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) { + n_outputs = n_tokens_all; + } else if (batch_all.logits) { for (uint32_t i = 0; i < n_tokens_all; ++i) { n_outputs += batch_all.logits[i] != 0; } - } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { + } else if (lctx.logits_all) { n_outputs = n_tokens_all; } else { // keep last output only @@ -12459,30 +12517,13 @@ static int llama_decode_internal( // no output res = nullptr; embd = nullptr; - } else if (!hparams.causal_attn) { - res = nullptr; // do not extract logits for embedding models such as BERT - - // token or sequence embeddings - embd = gf->nodes[gf->n_nodes - 1]; - - GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); } else if (cparams.embeddings) { - // the embeddings could be in the second to last tensor, or any of the previous tensors - int i_embd = gf->n_nodes - 2; - for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { - i_embd = gf->n_nodes - i; - if (i_embd < 0) { break; } - embd = gf->nodes[i_embd]; - } - GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); - - // TODO: use a per-batch flag to know when to skip logits while keeping embeddings - if (!cparams.causal_attn) { - res = nullptr; // do not extract logits when not needed - // skip computing logits - // TODO: is this safe? - gf->n_nodes = i_embd + 1; + res = nullptr; // do not extract logits for embedding case + embd = gf->nodes[gf->n_nodes - 1]; + if (strcmp(embd->name, "result_embd_pooled") != 0) { + embd = gf->nodes[gf->n_nodes - 2]; } + GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); } else { embd = nullptr; // do not extract embeddings when not needed GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); @@ -12551,11 +12592,10 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); } } break; - case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: { - GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); - // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); @@ -18112,6 +18152,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) ctx->abort_callback_data = abort_callback_data; } +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { ctx->cparams.causal_attn = causal_attn; } diff --git a/llama.h b/llama.h index da310ffaf..05d8b092b 100644 --- a/llama.h +++ b/llama.h @@ -174,6 +174,7 @@ extern "C" { LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, + LLAMA_POOLING_TYPE_LAST = 3, }; enum llama_split_mode { @@ -293,7 +294,6 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id - // (ignored if no pooling layer) // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model @@ -786,6 +786,10 @@ extern "C" { // Get the number of threads used for prompt and batch processing (multiple token). LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx); + // Set whether the model is in embeddings model or not + // If true, embeddings will be returned but logits will not + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); + // Set whether to use causal attention or not // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);