gritlm embeddings are back babeee
parent 97936078b7
commit 1ab6aeeeee
@@ -1304,6 +1304,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.pooling_type = params.pooling_type;
     cparams.defrag_thold = params.defrag_thold;
     cparams.offload_kqv = !params.no_kv_offload;
+    cparams.causal_attn = !params.embedding;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
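Note on the hunk above: with this change, creating a context from the common gpt_params with --embedding set now also disables causal attention. A minimal sketch of the same setup done directly against the C API follows; it is not part of this commit, the field names (embeddings, causal_attn) are taken from this patch, the model path is a placeholder, and the exact init/teardown calls can differ between llama.cpp revisions.

// Sketch only: build an embedding (non-causal) context directly, mirroring
// cparams.causal_attn = !params.embedding above. "model.gguf" is a placeholder.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * mdl = llama_load_model_from_file("model.gguf", mparams);
    if (mdl == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings  = true;   // request embeddings instead of logits
    cparams.causal_attn = false;  // what --embedding now implies via the common params

    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

    // ... tokenize, llama_decode() and read embeddings here ...

    llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();
    return 0;
}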
@@ -47,11 +47,13 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
         uint64_t n_inst = inputs_instruct.size();
 
+        /*
         // debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
         std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
             std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
         });
         std::printf("\n");
+        */
 
         // add input to batch (this increments n_tokens)
         for (uint64_t j = 0; j < n_toks; j++) {
@@ -88,12 +90,14 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         normalize(emb_unorm, emb_norm.data());
         result.push_back(emb_norm);
 
+        /*
         // print out emb_norm
         std::printf("embedding %ld: ", i);
-        for (uint64_t j = 0; j < 20; j++) {
+        for (uint64_t j = 0; j < n_embd; j++) {
             std::printf("%.5f ", emb_norm[j]);
         }
         std::printf("\n\n");
+        */
 
         llama_batch_free(batch);
     }
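The normalize() call above turns the pooled embedding into a unit vector before it is stored in result. The helper's body is not part of this diff; the following is a sketch of an L2 normalization consistent with the normalize(emb_unorm, emb_norm.data()) call, not necessarily the exact code in the example.

// Sketch of an L2 normalization matching the call signature used above;
// the real helper in the example may differ in detail.
#include <cmath>
#include <vector>

static void normalize(const std::vector<float> & in, float * out) {
    double sum = 0.0;
    for (const float v : in) {
        sum += (double) v * v;               // accumulate squared magnitude
    }
    const double norm = std::sqrt(sum);
    for (size_t i = 0; i < in.size(); i++) {
        out[i] = norm > 0.0 ? (float) (in[i] / norm) : 0.0f;
    }
}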
@@ -120,6 +124,7 @@ int main(int argc, char* argv[])
         );
         return true;
     };
+
     cparams.embeddings = true;
     cparams.causal_attn = false;
     cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
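With cparams.pooling_type = LLAMA_POOLING_TYPE_NONE the library hands back one embedding per token and the example does its own pooling; for GritLM that is a mean over the non-instruction tokens, which is why encode() computes n_inst in the earlier hunks. The sketch below shows that idea only. It is not lifted from this commit; it assumes llama_get_embeddings_ith() as the per-token accessor and borrows the variable names (n_inst, n_toks, n_embd) from the encode() hunks.

// Illustrative GritLM-style mean pooling that skips the instruction prefix.
// llama_get_embeddings_ith() returns the i-th token's embedding when the
// pooling type is LLAMA_POOLING_TYPE_NONE; everything else is assumed context.
#include <cstdint>
#include <vector>
#include "llama.h"

static std::vector<float> mean_pool_skip_instruction(llama_context * ctx, uint64_t n_inst, uint64_t n_toks, int n_embd) {
    std::vector<float> emb_unorm(n_embd, 0.0f);
    for (uint64_t k = n_inst; k < n_toks; k++) {
        const float * emb = llama_get_embeddings_ith(ctx, (int32_t) k);
        for (int d = 0; d < n_embd; d++) {
            emb_unorm[d] += emb[d];
        }
    }
    if (n_toks > n_inst) {
        for (int d = 0; d < n_embd; d++) {
            emb_unorm[d] /= (float) (n_toks - n_inst);
        }
    }
    return emb_unorm;
}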
@@ -8057,6 +8057,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         } else {
             // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
             const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
             assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
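To make the new n_stride concrete: if a batch has n_tokens = 3 but n_stride = 8 (say, a padded KV size), then for h = 0 row j of the mask occupies data[j*8 + 0] through data[j*8 + 7]. The first three entries per row are filled by the existing assignment of f, and the next hunk adds the loop that writes -INFINITY into the remaining five, so the padding columns can never receive attention.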
@@ -8075,7 +8076,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
 
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }
+
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
                     }
                 }
             }
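As a standalone illustration of the layout produced by the two hunks above (toy values, not llama.cpp code): rows of width n_stride, where columns [0, n_tokens) are 0.0f for same-sequence pairs and -INFINITY otherwise, and columns [n_tokens, n_stride) are padding that is always -INFINITY.

// Toy reconstruction of the padded non-causal KQ mask layout; sizes and
// sequence ids are made up for illustration.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 3;
    const int n_stride = 8;                    // padded row width, e.g. a padded KV size
    const std::vector<int> seq_id = {0, 0, 1}; // toy per-token sequence ids

    std::vector<float> data(n_tokens * n_stride, 0.0f);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            // tokens may only attend to tokens of the same sequence
            data[j*n_stride + i] = (seq_id[i] == seq_id[j]) ? 0.0f : -INFINITY;
        }
        for (int i = n_tokens; i < n_stride; ++i) {
            data[j*n_stride + i] = -INFINITY;  // padding columns are always masked
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_stride; ++i) {
            std::printf("%6.1f ", data[j*n_stride + i]);
        }
        std::printf("\n");
    }
    return 0;
}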