mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
rerank : cleanup + comments
This commit is contained in:
parent
6916ed1606
commit
62a45d12ef
@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||
}
|
||||
} else {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
|
@ -1419,7 +1419,7 @@ struct server_context {
|
||||
queue_results.send(res);
|
||||
}
|
||||
|
||||
void send_rank(const server_slot & slot, const llama_batch & batch) {
|
||||
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
||||
server_task_result res;
|
||||
res.id = slot.id_task;
|
||||
res.error = false;
|
||||
@ -1440,7 +1440,7 @@ struct server_context {
|
||||
|
||||
res.data = json {
|
||||
{"index", slot.index},
|
||||
{"rank", -1e6},
|
||||
{"score", -1e6},
|
||||
};
|
||||
|
||||
continue;
|
||||
@ -1448,11 +1448,11 @@ struct server_context {
|
||||
|
||||
res.data = json {
|
||||
{"index", slot.index},
|
||||
{"rank", embd[0]},
|
||||
{"score", embd[0]},
|
||||
};
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
|
||||
SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
|
||||
|
||||
queue_results.send(res);
|
||||
}
|
||||
@ -1493,6 +1493,9 @@ struct server_context {
|
||||
else if (prompt.is_array()) {
|
||||
std::vector<json> prompts = prompt;
|
||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||
for (size_t i = 1; i < prompts.size(); i++) {
|
||||
json qd;
|
||||
qd.push_back(prompts[0]);
|
||||
@ -1501,6 +1504,7 @@ struct server_context {
|
||||
create_task(data, true, qd);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||
for (size_t i = 0; i < prompts.size(); i++) {
|
||||
const auto & e = prompts[i];
|
||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||
@ -1965,6 +1969,7 @@ struct server_context {
|
||||
// track if this is an embedding or non-embedding batch
|
||||
// if we've added sampled tokens above, we are in non-embedding mode
|
||||
// -1: none, 0: non-embedding, 1: embedding
|
||||
// TODO: make enum
|
||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
@ -2133,6 +2138,7 @@ struct server_context {
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
}
|
||||
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
@ -2318,7 +2324,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
send_rank(slot, batch_view);
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
|
@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
|
||||
for (const auto & rank : ranks) {
|
||||
data.push_back(json{
|
||||
{"index", i++},
|
||||
{"relevance_score", json_value(rank, "rank", 0.0)},
|
||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -192,7 +192,7 @@ extern "C" {
|
||||
LLAMA_POOLING_TYPE_MEAN = 1,
|
||||
LLAMA_POOLING_TYPE_CLS = 2,
|
||||
LLAMA_POOLING_TYPE_LAST = 3,
|
||||
LLAMA_POOLING_TYPE_RANK = 4,
|
||||
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
|
||||
};
|
||||
|
||||
enum llama_attention_type {
|
||||
@ -202,9 +202,9 @@ extern "C" {
|
||||
};
|
||||
|
||||
enum llama_split_mode {
|
||||
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
|
||||
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
|
||||
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
||||
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
|
||||
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
|
||||
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
||||
};
|
||||
|
||||
// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
|
||||
@ -872,7 +872,8 @@ extern "C" {
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
|
||||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
//
|
||||
|
@ -17009,7 +17009,7 @@ static int llama_decode_internal(
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rank score - a single float per sequence
|
||||
// extract the rerank score - a single float per sequence
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
@ -17211,7 +17211,6 @@ static int llama_encode_internal(
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
@ -17228,6 +17227,13 @@ static int llama_encode_internal(
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
||||
// wait for an encoder model that requires this pooling type in order to test it
|
||||
// https://github.com/ggerganov/llama.cpp/pull/9510
|
||||
GGML_ABORT("RANK pooling not implemented yet");
|
||||
}
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
|
Loading…
Reference in New Issue
Block a user