From 6916ed160673d47e1e4f809f5b27ee68e2d9039e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Sep 2024 20:20:38 +0300 Subject: [PATCH] llama : avoid ggml_repeat during classification --- src/llama.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 1346198f8..f0f7b67cf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10243,9 +10243,6 @@ struct llm_build_context { cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); cur = ggml_tanh(ctx0, cur); cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); - - // broadcast across the embedding size to make it compatible with the llama_get_embeddings API - cur = ggml_repeat(ctx0, cur, inp); } break; default: { @@ -16997,7 +16994,6 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: - case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; @@ -17011,6 +17007,20 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rank score - a single float per sequence + auto & embd_seq_out = lctx.embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type");