rebase to new embed

This commit is contained in:
Douglas Hanley 2024-03-05 23:23:17 -06:00
parent 805ae529c4
commit 97936078b7
3 changed files with 18 additions and 20 deletions

View File

@ -39,24 +39,23 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
// testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
std::string input_string = instruction + sentences[i];
auto inputs = llama_tokenize(mdl, input_string, true, false);
uint64_t n_toks = inputs.size();
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
int n_inst = inputs_instruct.size();
uint64_t n_inst = inputs_instruct.size();
/*/
// debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
*/
// add input to batch (this increments n_tokens)
for (uint64_t j = 0; j < inputs.size(); j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, false);
for (uint64_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
@ -66,23 +65,22 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
llama_decode(ctx, batch);
// get embedding dimensions
int n_toks = inputs.size();
int n_embd = llama_n_embd(mdl);
uint64_t n_embd = llama_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
// sum up all token embeddings
for (int k = n_inst; k < n_toks; k++) {
for (uint64_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}
// divide by number of tokens (mean pooling)
int n_sent = n_toks - n_inst;
for (int j = 0; j < n_embd; j++) {
uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
@ -90,14 +88,12 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);
/*
// print out emb_norm
std::printf("embedding %ld: ", i);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < 20; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n");
*/
std::printf("\n\n");
llama_batch_free(batch);
}
@ -124,14 +120,14 @@ int main(int argc, char* argv[])
);
return true;
};
cparams.embedding = true;
cparams.embeddings = true;
cparams.causal_attn = false;
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
llama_backend_init();
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
auto ctx = llama_new_context_with_model(mdl, cparams);
auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);
// ### Embedding/Representation ### taken sample from here:
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
@ -167,7 +163,6 @@ int main(int argc, char* argv[])
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
llama_batch_free(bat);
llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();

View File

@ -1684,6 +1684,7 @@ struct llama_cparams {
bool embeddings;
bool offload_kqv;
bool causal_attn;
enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval;
@ -8029,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
if (hparams.causal_attn) {
if (cparams.causal_attn) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
@ -11992,6 +11993,7 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.causal_attn =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@ -12143,8 +12145,8 @@ struct llama_context * llama_new_context_with_model(
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.causal_attn = params.causal_attn;
cparams.pooling_type = params.pooling_type;
cparams.causal_attn = !params.embedding;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;

View File

@ -262,6 +262,7 @@ extern "C" {
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool causal_attn; // whether to use causal attention
// Abort callback
// if it returns true, execution of llama_decode() will be aborted