mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 15:24:35 +00:00
rerank : cleanup + comments
This commit is contained in:
parent
6916ed1606
commit
62a45d12ef
@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||||
for (int j = 0; j < n_embd_count; j++) {
|
for (int j = 0; j < n_embd_count; j++) {
|
||||||
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
|
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||||
|
@ -1419,7 +1419,7 @@ struct server_context {
|
|||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_rank(const server_slot & slot, const llama_batch & batch) {
|
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
||||||
server_task_result res;
|
server_task_result res;
|
||||||
res.id = slot.id_task;
|
res.id = slot.id_task;
|
||||||
res.error = false;
|
res.error = false;
|
||||||
@ -1440,7 +1440,7 @@ struct server_context {
|
|||||||
|
|
||||||
res.data = json {
|
res.data = json {
|
||||||
{"index", slot.index},
|
{"index", slot.index},
|
||||||
{"rank", -1e6},
|
{"score", -1e6},
|
||||||
};
|
};
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@ -1448,11 +1448,11 @@ struct server_context {
|
|||||||
|
|
||||||
res.data = json {
|
res.data = json {
|
||||||
{"index", slot.index},
|
{"index", slot.index},
|
||||||
{"rank", embd[0]},
|
{"score", embd[0]},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
|
SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
|
||||||
|
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
@ -1493,6 +1493,9 @@ struct server_context {
|
|||||||
else if (prompt.is_array()) {
|
else if (prompt.is_array()) {
|
||||||
std::vector<json> prompts = prompt;
|
std::vector<json> prompts = prompt;
|
||||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
// prompts[0] is the question
|
||||||
|
// the rest are the answers/documents
|
||||||
|
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||||
for (size_t i = 1; i < prompts.size(); i++) {
|
for (size_t i = 1; i < prompts.size(); i++) {
|
||||||
json qd;
|
json qd;
|
||||||
qd.push_back(prompts[0]);
|
qd.push_back(prompts[0]);
|
||||||
@ -1501,6 +1504,7 @@ struct server_context {
|
|||||||
create_task(data, true, qd);
|
create_task(data, true, qd);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||||
for (size_t i = 0; i < prompts.size(); i++) {
|
for (size_t i = 0; i < prompts.size(); i++) {
|
||||||
const auto & e = prompts[i];
|
const auto & e = prompts[i];
|
||||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||||
@ -1965,6 +1969,7 @@ struct server_context {
|
|||||||
// track if this is an embedding or non-embedding batch
|
// track if this is an embedding or non-embedding batch
|
||||||
// if we've added sampled tokens above, we are in non-embedding mode
|
// if we've added sampled tokens above, we are in non-embedding mode
|
||||||
// -1: none, 0: non-embedding, 1: embedding
|
// -1: none, 0: non-embedding, 1: embedding
|
||||||
|
// TODO: make enum
|
||||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
@ -2133,6 +2138,7 @@ struct server_context {
|
|||||||
slot.n_prompt_tokens_processed = 0;
|
slot.n_prompt_tokens_processed = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||||
@ -2318,7 +2324,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
send_rank(slot, batch_view);
|
send_rerank(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
|
@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
|
|||||||
for (const auto & rank : ranks) {
|
for (const auto & rank : ranks) {
|
||||||
data.push_back(json{
|
data.push_back(json{
|
||||||
{"index", i++},
|
{"index", i++},
|
||||||
{"relevance_score", json_value(rank, "rank", 0.0)},
|
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ extern "C" {
|
|||||||
LLAMA_POOLING_TYPE_MEAN = 1,
|
LLAMA_POOLING_TYPE_MEAN = 1,
|
||||||
LLAMA_POOLING_TYPE_CLS = 2,
|
LLAMA_POOLING_TYPE_CLS = 2,
|
||||||
LLAMA_POOLING_TYPE_LAST = 3,
|
LLAMA_POOLING_TYPE_LAST = 3,
|
||||||
LLAMA_POOLING_TYPE_RANK = 4,
|
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_attention_type {
|
enum llama_attention_type {
|
||||||
@ -872,7 +872,8 @@ extern "C" {
|
|||||||
|
|
||||||
// Get the embeddings for a sequence id
|
// Get the embeddings for a sequence id
|
||||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||||
// shape: [n_embd] (1-dimensional)
|
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
|
||||||
|
// otherwise: float[n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -17009,7 +17009,7 @@ static int llama_decode_internal(
|
|||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_RANK:
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
{
|
{
|
||||||
// extract the rank score - a single float per sequence
|
// extract the rerank score - a single float per sequence
|
||||||
auto & embd_seq_out = lctx.embd_seq;
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
|
|
||||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||||
@ -17211,7 +17211,6 @@ static int llama_encode_internal(
|
|||||||
case LLAMA_POOLING_TYPE_MEAN:
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
case LLAMA_POOLING_TYPE_LAST:
|
case LLAMA_POOLING_TYPE_LAST:
|
||||||
case LLAMA_POOLING_TYPE_RANK:
|
|
||||||
{
|
{
|
||||||
// extract sequence embeddings
|
// extract sequence embeddings
|
||||||
auto & embd_seq_out = lctx.embd_seq;
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
@ -17228,6 +17227,13 @@ static int llama_encode_internal(
|
|||||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
|
{
|
||||||
|
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
||||||
|
// wait for an encoder model that requires this pooling type in order to test it
|
||||||
|
// https://github.com/ggerganov/llama.cpp/pull/9510
|
||||||
|
GGML_ABORT("RANK pooling not implemented yet");
|
||||||
|
}
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
{
|
{
|
||||||
GGML_ABORT("unknown pooling type");
|
GGML_ABORT("unknown pooling type");
|
||||||
|
Loading…
Reference in New Issue
Block a user