diff --git a/common/common.cpp b/common/common.cpp
index c244db644..d8baf7782 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1304,6 +1304,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.pooling_type     = params.pooling_type;
     cparams.defrag_thold     = params.defrag_thold;
     cparams.offload_kqv      = !params.no_kv_offload;
+    cparams.causal_attn      = !params.embedding;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 41e444901..d06f64e24 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -47,11 +47,13 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
         uint64_t n_inst = inputs_instruct.size();
 
+        /*
         // debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
         std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
             std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
         });
         std::printf("\n");
+        */
 
         // add input to batch (this increments n_tokens)
         for (uint64_t j = 0; j < n_toks; j++) {
@@ -88,12 +90,14 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
         normalize(emb_unorm, emb_norm.data());
         result.push_back(emb_norm);
 
+        /*
         // print out emb_norm
         std::printf("embedding %ld: ", i);
-        for (uint64_t j = 0; j < 20; j++) {
+        for (uint64_t j = 0; j < n_embd; j++) {
             std::printf("%.5f ", emb_norm[j]);
         }
         std::printf("\n\n");
+        */
 
         llama_batch_free(batch);
     }
@@ -120,6 +124,7 @@ int main(int argc, char* argv[])
         );
         return true;
     };
 
+    cparams.embeddings = true;
     cparams.causal_attn = false;
     cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
diff --git a/llama.cpp b/llama.cpp
index fd0e58cca..04816ea9e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8057,6 +8057,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     } else {
         // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
         const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
         assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
@@ -8075,7 +8076,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         }
                     }
 
-                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                }
+
+                for (int i = n_tokens; i < n_stride; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                }
             }
         }
     }
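
For reference, here is a small, self-contained sketch (not part of the patch) of the mask layout the modified llama_set_inputs loop above appears to build: each of the n_tokens rows is n_stride floats wide, same-sequence tokens within the batch get 0.0f, and everything else, including the padding columns from n_tokens up to n_stride, is set to -INFINITY so it drops out of the softmax. The names seq_ids and the hard-coded sizes below are made up purely for illustration.

// Standalone illustration with hypothetical values, not code from the patch.
// Builds an n_tokens x n_stride KQ-style mask the way the hunk above does:
// 0.0f where batch token i shares a sequence with row j, -INFINITY elsewhere,
// and -INFINITY for the padding columns [n_tokens, n_stride).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 3;                      // tokens in the current batch
    const int n_stride = 5;                      // padded row width (e.g. kv_self.n)
    const std::vector<int> seq_ids = {0, 0, 1};  // sequence id of each batch token

    std::vector<float> mask(n_tokens * n_stride, 0.0f);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            // tokens may only attend to tokens of the same sequence within the batch
            mask[j*n_stride + i] = (seq_ids[i] == seq_ids[j]) ? 0.0f : -INFINITY;
        }
        for (int i = n_tokens; i < n_stride; ++i) {
            mask[j*n_stride + i] = -INFINITY;    // mask out the padding columns
        }
    }

    // print the mask so the row layout is visible
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_stride; ++i) {
            std::printf("%6s", std::isinf(mask[j*n_stride + i]) ? "-inf" : "0.0");
        }
        std::printf("\n");
    }
    return 0;
}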