rerank : cleanup + comments

This commit is contained in:
Georgi Gerganov 2024-09-25 16:58:54 +03:00
parent 6916ed1606
commit 62a45d12ef
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 27 additions and 14 deletions

View File

@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
} }
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) { for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]); LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
} }
} else { } else {
// print the first part of the embeddings or for a single prompt, the full embedding // print the first part of the embeddings or for a single prompt, the full embedding

View File

@ -1419,7 +1419,7 @@ struct server_context {
queue_results.send(res); queue_results.send(res);
} }
void send_rank(const server_slot & slot, const llama_batch & batch) { void send_rerank(const server_slot & slot, const llama_batch & batch) {
server_task_result res; server_task_result res;
res.id = slot.id_task; res.id = slot.id_task;
res.error = false; res.error = false;
@ -1440,7 +1440,7 @@ struct server_context {
res.data = json { res.data = json {
{"index", slot.index}, {"index", slot.index},
{"rank", -1e6}, {"score", -1e6},
}; };
continue; continue;
@ -1448,11 +1448,11 @@ struct server_context {
res.data = json { res.data = json {
{"index", slot.index}, {"index", slot.index},
{"rank", embd[0]}, {"score", embd[0]},
}; };
} }
SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str()); SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
queue_results.send(res); queue_results.send(res);
} }
@ -1493,6 +1493,9 @@ struct server_context {
else if (prompt.is_array()) { else if (prompt.is_array()) {
std::vector<json> prompts = prompt; std::vector<json> prompts = prompt;
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// prompts[0] is the question
// the rest are the answers/documents
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
for (size_t i = 1; i < prompts.size(); i++) { for (size_t i = 1; i < prompts.size(); i++) {
json qd; json qd;
qd.push_back(prompts[0]); qd.push_back(prompts[0]);
@ -1501,6 +1504,7 @@ struct server_context {
create_task(data, true, qd); create_task(data, true, qd);
} }
} else { } else {
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
for (size_t i = 0; i < prompts.size(); i++) { for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i]; const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) { if (e.is_string() || json_is_array_of_numbers(e)) {
@ -1965,6 +1969,7 @@ struct server_context {
// track if this is an embedding or non-embedding batch // track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode // if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding // -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
@ -2133,6 +2138,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0; slot.n_prompt_tokens_processed = 0;
} }
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter // cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
@ -2318,7 +2324,7 @@ struct server_context {
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
send_rank(slot, batch_view); send_rerank(slot, batch_view);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
continue; // continue loop of slots continue; // continue loop of slots

View File

@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
for (const auto & rank : ranks) { for (const auto & rank : ranks) {
data.push_back(json{ data.push_back(json{
{"index", i++}, {"index", i++},
{"relevance_score", json_value(rank, "rank", 0.0)}, {"relevance_score", json_value(rank, "score", 0.0)},
}); });
} }

View File

@ -192,7 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3, LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4, LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
}; };
enum llama_attention_type { enum llama_attention_type {
@ -872,7 +872,8 @@ extern "C" {
// Get the embeddings for a sequence id // Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// shape: [n_embd] (1-dimensional) // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
// //

View File

@ -17009,7 +17009,7 @@ static int llama_decode_internal(
} break; } break;
case LLAMA_POOLING_TYPE_RANK: case LLAMA_POOLING_TYPE_RANK:
{ {
// extract the rank score - a single float per sequence // extract the rerank score - a single float per sequence
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
@ -17211,7 +17211,6 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST: case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{ {
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
@ -17228,6 +17227,13 @@ static int llama_encode_internal(
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_RANK:
{
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
// wait for an encoder model that requires this pooling type in order to test it
// https://github.com/ggerganov/llama.cpp/pull/9510
GGML_ABORT("RANK pooling not implemented yet");
}
case LLAMA_POOLING_TYPE_UNSPECIFIED: case LLAMA_POOLING_TYPE_UNSPECIFIED:
{ {
GGML_ABORT("unknown pooling type"); GGML_ABORT("unknown pooling type");