diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index bf9043750..41e444901 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -39,24 +39,23 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         // testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
         std::string input_string = instruction + sentences[i];
         auto inputs = llama_tokenize(mdl, input_string, true, false);
+        uint64_t n_toks = inputs.size();

         // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
         // inputs.push_back(llama_token_eos(mdl));

         // we want to ignore instruction tokens for mean pooling
         auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
-        int n_inst = inputs_instruct.size();
+        uint64_t n_inst = inputs_instruct.size();

-        /*/
         // debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
         std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
             std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
         });
         std::printf("\n");
-        */

         // add input to batch (this increments n_tokens)
-        for (uint64_t j = 0; j < inputs.size(); j++) {
-            llama_batch_add(batch, inputs[j], j, { 0 }, false);
+        for (uint64_t j = 0; j < n_toks; j++) {
+            llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
         }

         // clear previous kv_cache values (irrelevant for embeddings)
@@ -66,23 +65,22 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         llama_decode(ctx, batch);

         // get embedding dimensions
-        int n_toks = inputs.size();
-        int n_embd = llama_n_embd(mdl);
+        uint64_t n_embd = llama_n_embd(mdl);

         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);

         // sum up all token embeddings
-        for (int k = n_inst; k < n_toks; k++) {
+        for (uint64_t k = n_inst; k < n_toks; k++) {
             float * emb = llama_get_embeddings_ith(ctx, k);
-            for (int j = 0; j < n_embd; j++) {
+            for (uint64_t j = 0; j < n_embd; j++) {
                 emb_unorm[j] += emb[j];
             }
         }

         // divide by number of tokens (mean pooling)
-        int n_sent = n_toks - n_inst;
-        for (int j = 0; j < n_embd; j++) {
+        uint64_t n_sent = n_toks - n_inst;
+        for (uint64_t j = 0; j < n_embd; j++) {
             emb_unorm[j] /= n_sent;
         }
@@ -90,14 +88,12 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         normalize(emb_unorm, emb_norm.data());
         result.push_back(emb_norm);

-        /*
         // print out emb_norm
         std::printf("embedding %ld: ", i);
-        for (int j = 0; j < n_embd; j++) {
+        for (uint64_t j = 0; j < 20; j++) {
             std::printf("%.5f ", emb_norm[j]);
         }
-        std::printf("\n");
-        */
+        std::printf("\n\n");

         llama_batch_free(batch);
     }
@@ -124,14 +120,14 @@ int main(int argc, char* argv[])
         );
         return true;
     };

-    cparams.embedding = true;
+    cparams.embeddings = true;
+    cparams.causal_attn = false;
     cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;

     llama_backend_init();

     auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
     auto ctx = llama_new_context_with_model(mdl, cparams);
-    auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);

     // ### Embedding/Representation ### taken sample from here:
     // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
@@ -167,7 +163,6 @@ int main(int argc, char* argv[])
         std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
     }

-    llama_batch_free(bat);
     llama_free(ctx);
     llama_free_model(mdl);
     llama_backend_free();
diff --git a/llama.cpp b/llama.cpp
index 1442dd4d2..fd0e58cca 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1684,6 +1684,7 @@ struct llama_cparams {
     bool embeddings;
     bool offload_kqv;
+    bool causal_attn;

     enum llama_pooling_type pooling_type;

     ggml_backend_sched_eval_callback cb_eval;
@@ -8029,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }

-    if (hparams.causal_attn) {
+    if (cparams.causal_attn) {
         const int64_t n_kv = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;

@@ -11992,6 +11993,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all          =*/ false,
         /*.embeddings          =*/ false,
         /*.offload_kqv         =*/ true,
+        /*.causal_attn         =*/ true,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -12143,8 +12145,8 @@ struct llama_context * llama_new_context_with_model(
     cparams.defrag_thold = params.defrag_thold;
     cparams.embeddings   = params.embeddings;
     cparams.offload_kqv  = params.offload_kqv;
+    cparams.causal_attn  = params.causal_attn;
     cparams.pooling_type = params.pooling_type;
-    cparams.causal_attn = !params.embedding;

     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
diff --git a/llama.h b/llama.h
index 3dc162b07..6265d6901 100644
--- a/llama.h
+++ b/llama.h
@@ -262,6 +262,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool causal_attn; // whether to use causal attention

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
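
The encode() changes in examples/gritlm/gritlm.cpp above implement the pooling GritLM expects: the instruction is tokenized separately only to learn how many leading positions to skip, the remaining token embeddings are mean-pooled, and the result is L2-normalized. Below is a minimal standalone sketch of that pooling step under those assumptions; pool_skip_instruction, the dummy embeddings, and all names here are illustrative only and are not part of the patch or the llama.cpp API.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // mean-pool the embeddings of tokens [n_inst, n_toks), then L2-normalize the result
    static std::vector<float> pool_skip_instruction(const std::vector<std::vector<float>> & token_embd, uint64_t n_inst) {
        const uint64_t n_toks = token_embd.size();
        const uint64_t n_embd = token_embd[0].size();

        std::vector<float> out(n_embd, 0.0f);

        // sum the embeddings of the non-instruction tokens
        for (uint64_t k = n_inst; k < n_toks; k++) {
            for (uint64_t j = 0; j < n_embd; j++) {
                out[j] += token_embd[k][j];
            }
        }

        // divide by the number of pooled tokens (mean pooling)
        const float n_sent = float(n_toks - n_inst);
        for (float & v : out) {
            v /= n_sent;
        }

        // L2 normalization
        float norm = 0.0f;
        for (float v : out) {
            norm += v*v;
        }
        norm = std::sqrt(norm);
        for (float & v : out) {
            v /= norm;
        }

        return out;
    }

    int main() {
        // 4 tokens with 3-dimensional embeddings; pretend the first 2 tokens are the instruction
        const std::vector<std::vector<float>> token_embd = {
            { 1.0f, 0.0f, 0.0f },
            { 0.0f, 1.0f, 0.0f },
            { 0.5f, 0.5f, 0.0f },
            { 0.0f, 0.0f, 1.0f },
        };

        for (float v : pool_skip_instruction(token_embd, 2)) {
            std::printf("%.5f ", v);
        }
        std::printf("\n");
    }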
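
The llama.cpp/llama.h changes expose the attention-mask choice as a context parameter (cparams.causal_attn) instead of reading it from the model hparams, so an embedding caller such as gritlm.cpp can request bidirectional attention with causal_attn = false while generation keeps the default of true. The sketch below is only a toy illustration of what the flag controls; it is not llama.cpp's actual mask construction (which is KV-cache shaped), and build_mask is a made-up name.

    #include <cstdio>
    #include <limits>
    #include <vector>

    // build an n_tokens x n_tokens additive attention mask:
    //   causal     - token i may only attend to positions j <= i (blocked entries get -inf)
    //   non-causal - every token may attend to every position (what embedding extraction wants)
    static std::vector<float> build_mask(int n_tokens, bool causal_attn) {
        const float neg_inf = -std::numeric_limits<float>::infinity();

        std::vector<float> mask(n_tokens*n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; i++) {
            for (int j = 0; j < n_tokens; j++) {
                if (causal_attn && j > i) {
                    mask[i*n_tokens + j] = neg_inf; // block attention to future positions
                }
            }
        }
        return mask;
    }

    int main() {
        const bool modes[2] = { true, false };

        for (bool causal : modes) {
            std::printf("causal_attn = %s\n", causal ? "true" : "false");
            const std::vector<float> mask = build_mask(4, causal);
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    std::printf("%6.1f ", mask[i*4 + j]);
                }
                std::printf("\n");
            }
            std::printf("\n");
        }
    }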