Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-02 14:54:35 +00:00
rebase to new embed

commit 97936078b7, parent 805ae529c4
@@ -39,24 +39,23 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
// testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
std::string input_string = instruction + sentences[i];
auto inputs = llama_tokenize(mdl, input_string, true, false);
uint64_t n_toks = inputs.size();
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
// inputs.push_back(llama_token_eos(mdl));

// we want to ignore instruction tokens for mean pooling
auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
int n_inst = inputs_instruct.size();
uint64_t n_inst = inputs_instruct.size();

/*/
// debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
*/

// add input to batch (this increments n_tokens)
for (uint64_t j = 0; j < inputs.size(); j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, false);
for (uint64_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}

// clear previous kv_cache values (irrelevant for embeddings)
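
Note on the batching change above: only positions j >= n_inst (the sentence tokens that follow the instruction prefix) request output, so the instruction contributes attention context but is excluded from pooling. Below is a self-contained sketch of just that flag rule on made-up token counts; it is not part of the commit, and only the names n_inst and n_toks are taken from the example.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical counts: 5 instruction tokens followed by 3 sentence tokens
        const uint64_t n_inst = 5;
        const uint64_t n_toks = 8;

        // mirror of the batching rule in the patch: request output only for j >= n_inst
        std::vector<bool> want_output(n_toks);
        for (uint64_t j = 0; j < n_toks; j++) {
            want_output[j] = (j >= n_inst);
        }

        for (uint64_t j = 0; j < n_toks; j++) {
            std::printf("token %llu -> output: %s\n",
                        (unsigned long long) j, want_output[j] ? "yes" : "no");
        }
        return 0;
    }
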
@@ -66,23 +65,22 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
llama_decode(ctx, batch);

// get embedding dimensions
int n_toks = inputs.size();
int n_embd = llama_n_embd(mdl);
uint64_t n_embd = llama_n_embd(mdl);

// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);

// sum up all token embeddings
for (int k = n_inst; k < n_toks; k++) {
for (uint64_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}

// divide by number of tokens (mean pooling)
int n_sent = n_toks - n_inst;
for (int j = 0; j < n_embd; j++) {
uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}

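The hunk above implements mean pooling: the per-token embeddings of the sentence tokens are summed and then divided by the sentence token count. A self-contained sketch of the same arithmetic on made-up data (two tokens, four dimensions), independent of the llama API:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // made-up per-token embeddings for the sentence tokens only (n_sent = 2, n_embd = 4)
        const std::vector<std::vector<float>> token_embd = {
            { 0.2f, -0.4f, 0.1f,  0.3f },
            { 0.6f,  0.0f, 0.5f, -0.1f },
        };
        const uint64_t n_sent = token_embd.size();
        const uint64_t n_embd = token_embd[0].size();

        // sum up all token embeddings, then divide by the token count (mean pooling)
        std::vector<float> emb_unorm(n_embd, 0.0f);
        for (uint64_t k = 0; k < n_sent; k++) {
            for (uint64_t j = 0; j < n_embd; j++) {
                emb_unorm[j] += token_embd[k][j];
            }
        }
        for (uint64_t j = 0; j < n_embd; j++) {
            emb_unorm[j] /= n_sent;
        }

        for (uint64_t j = 0; j < n_embd; j++) {
            std::printf("%.3f ", emb_unorm[j]);
        }
        std::printf("\n");
        return 0;
    }
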
@@ -90,14 +88,12 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);

/*
// print out emb_norm
std::printf("embedding %ld: ", i);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < 20; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n");
*/
std::printf("\n\n");

llama_batch_free(batch);
}
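
The normalize() helper called above is not shown in this diff; it presumably L2-normalizes the pooled vector so that cosine similarity later reduces to a dot product. A self-contained sketch of such a normalization follows; the helper name normalize_l2 and the sample values are illustrative, not taken from the commit.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // L2-normalize: divide each component by the Euclidean norm of the vector
    static void normalize_l2(const std::vector<float> & in, float * out) {
        double sum = 0.0;
        for (float v : in) {
            sum += (double) v * v;
        }
        const double norm = std::sqrt(sum);
        for (size_t i = 0; i < in.size(); i++) {
            out[i] = norm > 0.0 ? (float) (in[i] / norm) : 0.0f;
        }
    }

    int main() {
        const std::vector<float> emb_unorm = { 3.0f, 4.0f, 0.0f };
        std::vector<float> emb_norm(emb_unorm.size());
        normalize_l2(emb_unorm, emb_norm.data());
        for (float v : emb_norm) {
            std::printf("%.3f ", v);   // expected: 0.600 0.800 0.000
        }
        std::printf("\n");
        return 0;
    }
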
@@ -124,14 +120,14 @@ int main(int argc, char* argv[])
);
return true;
};
cparams.embedding = true;
cparams.embeddings = true;
cparams.causal_attn = false;
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;

llama_backend_init();

auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
auto ctx = llama_new_context_with_model(mdl, cparams);
auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);

// ### Embedding/Representation ### taken sample from here:
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
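
The setup above switches the context into representation mode: embeddings are extracted, causal masking is turned off via the new causal_attn parameter, and built-in pooling is disabled because the example pools manually. A condensed sketch of that configuration against the llama.h fields as changed by this commit; error handling is minimal and the command-line model path is illustrative.

    #include "llama.h"

    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            std::printf("usage: %s <model.gguf>\n", argv[0]);
            return 1;
        }

        llama_backend_init();

        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();

        // representation mode, as configured by the example above
        cparams.embeddings   = true;                     // extract embeddings after decode
        cparams.causal_attn  = false;                    // bidirectional attention for embeddings
        cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;  // the example pools manually

        llama_model * mdl = llama_load_model_from_file(argv[1], mparams);
        if (mdl == nullptr) {
            llama_backend_free();
            return 1;
        }

        llama_context * ctx = llama_new_context_with_model(mdl, cparams);
        if (ctx == nullptr) {
            llama_free_model(mdl);
            llama_backend_free();
            return 1;
        }

        // ... tokenize, fill a llama_batch, llama_decode(), read embeddings ...

        llama_free(ctx);
        llama_free_model(mdl);
        llama_backend_free();
        return 0;
    }

The same three fields are what the rest of the patch plumbs through llama_cparams further down.
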
@@ -167,7 +163,6 @@ int main(int argc, char* argv[])
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}

llama_batch_free(bat);
llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();
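
The printed score is the cosine similarity between a query embedding and a document embedding; the helper that computes cosine_sim_q1_d1 is outside this hunk. A self-contained sketch of the usual formula, dot(a, b) / (|a| * |b|), on made-up vectors:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // cosine similarity: dot(a, b) / (|a| * |b|); assumes a and b have equal length
    static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
        double dot = 0.0;
        double na  = 0.0;
        double nb  = 0.0;
        for (size_t i = 0; i < a.size(); i++) {
            dot += (double) a[i] * b[i];
            na  += (double) a[i] * a[i];
            nb  += (double) b[i] * b[i];
        }
        const double denom = std::sqrt(na) * std::sqrt(nb);
        return denom > 0.0 ? (float) (dot / denom) : 0.0f;
    }

    int main() {
        // made-up vectors standing in for a query and a document embedding
        const std::vector<float> q = { 0.1f, 0.7f, 0.2f };
        const std::vector<float> d = { 0.2f, 0.6f, 0.1f };
        std::printf("cosine similarity: %.3f\n", cosine_similarity(q, d));
        return 0;
    }

Because encode() normalizes its embeddings before storing them, the dot product alone would give the same score; the full formula is shown for generality.
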
llama.cpp

@@ -1684,6 +1684,7 @@ struct llama_cparams {

bool embeddings;
bool offload_kqv;
bool causal_attn;
enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;
@@ -8029,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}

if (hparams.causal_attn) {
if (cparams.causal_attn) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;

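This change reads the causal-attention switch from cparams instead of the model hyperparameters, so a context created with causal_attn = false gets a bidirectional mask while generation contexts keep the usual lower-triangular one. A standalone sketch contrasting the two mask shapes (1 = may attend, 0 = masked); it only illustrates the pattern and is not the kernel code.

    #include <cstdio>
    #include <vector>

    // build an attention mask for n tokens: 1 = may attend, 0 = masked out
    static std::vector<std::vector<int>> build_mask(int n, bool causal_attn) {
        std::vector<std::vector<int>> mask(n, std::vector<int>(n, 1));
        if (causal_attn) {
            for (int i = 0; i < n; i++) {
                for (int j = i + 1; j < n; j++) {
                    mask[i][j] = 0;   // a token cannot attend to later positions
                }
            }
        }
        return mask;
    }

    static void print_mask(const char * title, const std::vector<std::vector<int>> & mask) {
        std::printf("%s\n", title);
        for (const auto & row : mask) {
            for (int v : row) {
                std::printf("%d ", v);
            }
            std::printf("\n");
        }
    }

    int main() {
        print_mask("causal (generation):",     build_mask(4, true));
        print_mask("non-causal (embeddings):", build_mask(4, false));
        return 0;
    }

In the patch this choice is made per context at creation time, which is what lets the embedding example disable causal masking without touching the model or its hyperparameters.
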
@@ -11992,6 +11993,7 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.causal_attn =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -12143,8 +12145,8 @@ struct llama_context * llama_new_context_with_model(
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.causal_attn = params.causal_attn;
cparams.pooling_type = params.pooling_type;
cparams.causal_attn = !params.embedding;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
llama.h
@@ -262,6 +262,7 @@ extern "C" {
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool causal_attn; // whether to use causal attention

// Abort callback
// if it returns true, execution of llama_decode() will be aborted
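
The embeddings flag documented above is what makes llama_get_embeddings_ith() return data after llama_decode(). The following compile-only sketch (not from the commit) mirrors the loop in encode() to gather per-position vectors; the helper name collect_embeddings and its range parameters are illustrative.

    #include "llama.h"

    #include <cstdint>
    #include <vector>

    // Gather per-token embeddings for batch positions [i0, i1) after llama_decode().
    // Assumes the context was created with embeddings enabled and that the batch
    // requested output for those positions (as the example does for j >= n_inst).
    std::vector<std::vector<float>> collect_embeddings(llama_context * ctx, int32_t i0, int32_t i1) {
        const int32_t n_embd = llama_n_embd(llama_get_model(ctx));
        std::vector<std::vector<float>> out;
        for (int32_t k = i0; k < i1; k++) {
            const float * emb = llama_get_embeddings_ith(ctx, k);
            out.emplace_back(emb, emb + n_embd);
        }
        return out;
    }
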