From e66da356a41530137161d20feb224c76f5bc13ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 14:06:33 +0200 Subject: [PATCH] llama : add pooling switch --- llama.cpp | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/llama.cpp b/llama.cpp index c5c78714d..6245af221 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8113,7 +8113,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + const llama_pos pos = batch.pos[i]; if (pos == 0) { data[seq_id] = i; } @@ -8379,10 +8379,17 @@ static int llama_decode_internal( if (batch.logits[i] == 0) { continue; } - if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float)); - } else { - ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); + switch (hparams.pooling_type) { + case LLAMA_POOLING_TYPE_CLS: + ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float)); + break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_NONE: + ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); + break; + default: + GGML_ASSERT(false && "unknown pooling type"); + break; } } } @@ -8680,19 +8687,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ASSERT(false); - return unicode_to_bytes_bpe(token_data.text); - } - case LLAMA_VOCAB_TYPE_WPM: { - GGML_ASSERT(false); - } - default: - GGML_ASSERT(false); + case LLAMA_VOCAB_TYPE_SPM: { + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); + } + case LLAMA_VOCAB_TYPE_BPE: { + GGML_ASSERT(false); + return unicode_to_bytes_bpe(token_data.text); + } + case LLAMA_VOCAB_TYPE_WPM: { + GGML_ASSERT(false); + } + default: + GGML_ASSERT(false); } }