mostly style fixes; fix KQ_mask comment

Douglas Hanley 2024-03-09 13:03:46 -06:00
parent 03acc82a85
commit b54afce9f4
2 changed files with 36 additions and 34 deletions

View File

@@ -4,6 +4,8 @@
 #include <string>
 #include <vector>

+// #define GRIT_DEBUG

 static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
     float dot = 0.0f;
     for (uint64_t i = 0; i < v1.size(); ++i) {
@@ -40,21 +42,21 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
         auto n_toks = (int32_t)inputs.size();

-        // testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
-        // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
+        // GritLM seems to have embed EOS = ""
+        // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
         // inputs.push_back(llama_token_eos(mdl));

         // we want to ignore instruction tokens for mean pooling
         std::vector<llama_token> inputs_instruct = llama_tokenize(mdl, instruction, true, false);
         auto n_inst = (int32_t)inputs_instruct.size();

-        /*
+#ifdef GRIT_DEBUG
         // debug tokens - should be matching as referenced in the GritLM sample
         std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
             std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
         });
         std::printf("\n");
-        */
+#endif

         // add input to batch (this increments n_tokens)
         for (int32_t j = 0; j < n_toks; j++) {
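The "ignore instruction tokens for mean pooling" comment above is the heart of GRIT-style embedding: only the sample tokens should contribute to the pooled vector, not the instruction prefix. A minimal sketch of that idea as a standalone helper; the names n_inst, n_toks, and n_embd mirror the variables in this hunk, but the function itself is hypothetical and not part of the example:

#include <cstdint>
#include <vector>

// Hypothetical helper: mean-pool per-token embeddings while skipping the
// first n_inst (instruction) tokens, so only the sample tokens contribute.
static std::vector<float> mean_pool_skip_instruction(
        const std::vector<std::vector<float>> & token_embd, // one n_embd-sized vector per token
        int32_t n_inst) {
    const int32_t  n_toks = (int32_t)token_embd.size();
    const uint64_t n_embd = token_embd.empty() ? 0 : token_embd[0].size();
    std::vector<float> pooled(n_embd, 0.0f);
    const int32_t n_used = n_toks - n_inst;
    if (n_used <= 0) {
        return pooled; // nothing but instruction tokens - return zeros
    }
    for (int32_t t = n_inst; t < n_toks; t++) {
        for (uint64_t j = 0; j < n_embd; j++) {
            pooled[j] += token_embd[t][j];
        }
    }
    for (uint64_t j = 0; j < n_embd; j++) {
        pooled[j] /= (float)n_used;
    }
    return pooled;
}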
@@ -91,14 +93,14 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         normalize(emb_unorm, emb_norm.data());
         result.push_back(emb_norm);

-        /*
+#ifdef GRIT_DEBUG
         // print out emb_norm
         std::printf("embedding %ld: ", i);
         for (uint64_t j = 0; j < n_embd; j++) {
             std::printf("%.5f ", emb_norm[j]);
         }
         std::printf("\n\n");
-        */
+#endif
     }

     llama_batch_free(batch);
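Both hunks above replace commented-out /* ... */ blocks with an #ifdef GRIT_DEBUG guard, so the diagnostics are toggled at compile time by uncommenting the "// #define GRIT_DEBUG" line added at the top of the file. A minimal standalone sketch of that pattern (not code from the example itself):

#include <cstdio>

// Uncomment to compile the diagnostics in, mirroring the
// "// #define GRIT_DEBUG" line added at the top of the example.
// #define GRIT_DEBUG

int main() {
    float emb[4] = { 0.1f, 0.2f, 0.3f, 0.4f };

#ifdef GRIT_DEBUG
    // Compiled only when GRIT_DEBUG is defined; otherwise the block is
    // removed entirely at compile time, unlike a runtime flag.
    for (int j = 0; j < 4; j++) {
        std::printf("%.5f ", emb[j]);
    }
    std::printf("\n");
#endif

    return 0;
}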
@@ -128,6 +130,7 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool
    std::vector<std::string> pieces;

    const llama_model * mdl = llama_get_model(ctx);
+   llama_token eos_token = llama_token_eos(mdl);
    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -135,25 +138,24 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool
    while (true) {
        llama_batch_clear(bat);
-       for (auto i = 0; i < inputs.size(); i++)
+       for (auto i = 0; i < inputs.size(); i++) {
            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1);
+       }
        inputs.clear();

        llama_decode(ctx, bat);
        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

        auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
-       for (auto token = 0; token < candidates.size(); token++)
+       for (auto token = 0; token < candidates.size(); token++) {
            candidates[token] = llama_token_data{ token, logits[token], 0.0f };
+       }

        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
-       if (token == llama_token_eos(mdl))
+       if (token == eos_token) {
            break;
+       }

        std::string piece = llama_token_to_piece(ctx, token);
        if (stream) {
@@ -194,8 +196,8 @@ int main(int argc, char* argv[])
    // create new context - default mode is causal
    llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams);

-   // ### Embedding/Representation ### samples taken from here:
-   // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
+   // samples taken from here: https://github.com/ContextualAI/gritlm#basic
+   // Embedding/Representation
    {
        std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
@@ -224,8 +226,8 @@ int main(int argc, char* argv[])
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
    }

-   // ### Generation ###
-   // # GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
+   // Generation
+   // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
        const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
        std::string response = generate(causal_ctx, prompt, true);
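The cosine-similarity printout above relies on the dot_product helper from the first hunk; since encode() pushes back normalized embeddings (normalize(emb_unorm, emb_norm.data())), cosine similarity reduces to a plain dot product. A rough sketch of that computation, assuming both inputs are already unit length; the function name here is hypothetical, not part of the example:

#include <cassert>
#include <cstdint>
#include <vector>

// Assumes v1 and v2 are already L2-normalized, as encode() guarantees,
// so their cosine similarity is simply the dot product.
static float cosine_similarity_normalized(const std::vector<float> & v1,
                                          const std::vector<float> & v2) {
    assert(v1.size() == v2.size());
    float dot = 0.0f;
    for (uint64_t i = 0; i < v1.size(); ++i) {
        dot += v1[i] * v2[i];
    }
    return dot; // in [-1, 1]
}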

View File

@@ -8061,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                }
            }
        } else {
-           // with causal attention, the mask needs to match the kv cache size
+           // for models using the kv cache, the mask needs to match the kv cache size
            const int64_t n_tokens = batch.n_tokens;
            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
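The reworded comment describes the n_stride choice just below it: when the KV cache is in play (the causal path), each row of the KQ mask must span kv_self.n cells, while a non-causal pass only needs n_tokens columns. The following is a much-simplified illustration of that layout, not the actual llama_set_inputs code; names like n_kv and pos are stand-ins, and it assumes cache cell j simply holds position j:

#include <cmath>
#include <cstdint>
#include <vector>

// Simplified sketch: build a row-major KQ mask whose row stride depends on
// whether the KV cache (causal path) is in use. Entries are 0.0f where
// attention is allowed and -INFINITY where it is masked out.
// pos[i] is the sequence position of token i (pos must hold n_tokens entries).
static std::vector<float> build_kq_mask(bool causal_attn, int64_t n_tokens,
                                        int64_t n_kv, const std::vector<int64_t> & pos) {
    const int64_t n_stride = causal_attn ? n_kv : n_tokens;
    std::vector<float> mask(n_tokens * n_stride, -INFINITY);
    for (int64_t i = 0; i < n_tokens; i++) {
        for (int64_t j = 0; j < n_stride; j++) {
            const bool allowed = causal_attn ? (j <= pos[i])   // only cache cells up to this token's position
                                             : (j < n_tokens); // full bidirectional attention
            if (allowed) {
                mask[i * n_stride + j] = 0.0f;
            }
        }
    }
    return mask;
}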