mostly style fixes; fix KQ_mask comment

This commit is contained in:
Douglas Hanley 2024-03-09 13:03:46 -06:00
parent 03acc82a85
commit b54afce9f4
2 changed files with 36 additions and 34 deletions

View File

@ -4,7 +4,9 @@
#include <string>
#include <vector>
static float dot_product(const std::vector<float>& v1, const std::vector<float>& v2) {
// #define GRIT_DEBUG
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
float dot = 0.0f;
for (uint64_t i = 0; i < v1.size(); ++i) {
dot += v1[i] * v2[i];
@ -12,22 +14,22 @@ static float dot_product(const std::vector<float>& v1, const std::vector<float>&
return dot;
}
static float norm(const std::vector<float>& v) {
static float norm(const std::vector<float> & v) {
return std::sqrt(dot_product(v, v));
}
static float cosine_similarity(const std::vector<float>& v1, const std::vector<float>& v2) {
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
return dot_product(v1, v2) / (norm(v1) * norm(v2));
}
static void normalize(const std::vector<float>& in, float* out) {
static void normalize(const std::vector<float> & in, float * out) {
float inorm = norm(in);
for (uint64_t i = 0; i < in.size(); i++) {
out[i] = in[i] / inorm;
}
}
static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vector<std::string>& sentences, const std::string& instruction) {
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
auto result = std::vector<std::vector<float>>{};
auto mdl = llama_get_model(ctx);
@ -40,21 +42,21 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
auto n_toks = (int32_t)inputs.size();
// testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
// GritLM seems to have embed EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
std::vector<llama_token> inputs_instruct = llama_tokenize(mdl, instruction, true, false);
auto n_inst = (int32_t)inputs_instruct.size();
/*
#ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
*/
#endif
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
@ -75,7 +77,7 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float* emb = llama_get_embeddings_ith(ctx, k);
float * emb = llama_get_embeddings_ith(ctx, k);
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
@ -91,24 +93,24 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);
/*
#ifdef GRIT_DEBUG
// print out emb_norm
std::printf("embedding %ld: ", i);
for (uint64_t j = 0; j < n_embd; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n\n");
*/
#endif
}
llama_batch_free(batch);
return result;
}
static std::string aggregate_pieces(const std::vector<std::string>& pieces) {
static std::string aggregate_pieces(const std::vector<std::string> & pieces) {
// calculate total length required
size_t length = 0;
for (const auto& str : pieces) {
for (const auto & str : pieces) {
length += str.size();
}
@ -117,17 +119,18 @@ static std::string aggregate_pieces(const std::vector<std::string>& pieces) {
result.reserve(length);
// append pieces
for (const auto& str : pieces) {
for (const auto & str : pieces) {
result += str;
}
return result;
}
static std::string generate(llama_context* ctx, const std::string& prompt, bool stream) {
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::vector<std::string> pieces;
const llama_model* mdl = llama_get_model(ctx);
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@ -135,25 +138,24 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool
while (true) {
llama_batch_clear(bat);
for (auto i = 0; i < inputs.size(); i++)
for (auto i = 0; i < inputs.size(); i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1);
}
inputs.clear();
llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
for (auto token = 0; token < candidates.size(); token++)
for (auto token = 0; token < candidates.size(); token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
}
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == llama_token_eos(mdl))
if (token == eos_token) {
break;
}
std::string piece = llama_token_to_piece(ctx, token);
if (stream) {
@ -169,11 +171,11 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool
return aggregate_pieces(pieces);
}
static std::string gritlm_instruction(const std::string& instruction) {
static std::string gritlm_instruction(const std::string & instruction) {
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}
int main(int argc, char* argv[])
int main(int argc, char * argv[])
{
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
@ -185,17 +187,17 @@ int main(int argc, char* argv[])
llama_backend_init();
llama_model* mdl = llama_load_model_from_file(params.model.c_str(), mparams);
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
// create new context - set to embedding mode
llama_context* embd_ctx = llama_new_context_with_model(mdl, cparams);
llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams);
llama_set_embeddings(embd_ctx, true);
// create new context - default mode is causal
llama_context* causal_ctx = llama_new_context_with_model(mdl, cparams);
llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);
// ### Embedding/Representation ### samples taken from here:
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
// samples taken from here: https://github.com/ContextualAI/gritlm#basic
// Embedding/Representation
{
std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
@ -224,8 +226,8 @@ int main(int argc, char* argv[])
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
// ### Generation ###
// # GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
// Generation
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(causal_ctx, prompt, true);

View File

@ -8061,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
} else {
// with causal attention, the mask needs to match the kv cache size
// for models using the kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;