mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-09-22 21:16:20 +00:00
ggml-ci
This commit is contained in:
parent
f03bcd84e7
commit
5f95dccea8
@ -1103,7 +1103,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||||||
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
|
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
|
||||||
else { throw std::invalid_argument("invalid value"); }
|
else { throw std::invalid_argument("invalid value"); }
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--attention"}, "{causal,non,causal}",
|
{"--attention"}, "{causal,non,causal}",
|
||||||
"attention type for embeddings, use model default if unspecified",
|
"attention type for embeddings, use model default if unspecified",
|
||||||
|
@ -92,6 +92,7 @@ enum server_task_type {
|
|||||||
enum server_task_cmpl_type {
|
enum server_task_cmpl_type {
|
||||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
SERVER_TASK_CMPL_TYPE_NORMAL,
|
||||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
||||||
|
SERVER_TASK_CMPL_TYPE_RERANK,
|
||||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
SERVER_TASK_CMPL_TYPE_INFILL,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -172,6 +173,7 @@ struct server_slot {
|
|||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||||
|
|
||||||
bool has_next_token = true;
|
bool has_next_token = true;
|
||||||
bool truncated = false;
|
bool truncated = false;
|
||||||
bool stopped_eos = false;
|
bool stopped_eos = false;
|
||||||
@ -942,8 +944,17 @@ struct server_context {
|
|||||||
slot.prompt = *prompt;
|
slot.prompt = *prompt;
|
||||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||||
slot.prompt = prompt->at(0);
|
slot.prompt = prompt->at(0);
|
||||||
|
} else if (prompt->is_array() && prompt->size() > 1) {
|
||||||
|
// array of strings
|
||||||
|
for (const auto & el : *prompt) {
|
||||||
|
if (!el.is_string()) {
|
||||||
|
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slot.prompt = *prompt;
|
||||||
} else {
|
} else {
|
||||||
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1368,6 +1379,7 @@ struct server_context {
|
|||||||
|
|
||||||
res.data = json {
|
res.data = json {
|
||||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||||
|
{"index", slot.index},
|
||||||
};
|
};
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@ -1386,6 +1398,44 @@ struct server_context {
|
|||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void send_rank(const server_slot & slot, const llama_batch & batch) {
|
||||||
|
server_task_result res;
|
||||||
|
res.id = slot.id_task;
|
||||||
|
res.error = false;
|
||||||
|
res.stop = true;
|
||||||
|
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
|
if (embd == NULL) {
|
||||||
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embd == NULL) {
|
||||||
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||||
|
|
||||||
|
res.data = json {
|
||||||
|
{"index", slot.index},
|
||||||
|
{"rank", -1e6},
|
||||||
|
};
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
res.data = json {
|
||||||
|
{"index", slot.index},
|
||||||
|
{"rank", embd[0]},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
|
||||||
|
|
||||||
|
queue_results.send(res);
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Functions to create new task(s) and receive result(s)
|
// Functions to create new task(s) and receive result(s)
|
||||||
//
|
//
|
||||||
@ -1421,13 +1471,23 @@ struct server_context {
|
|||||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||||
else if (prompt.is_array()) {
|
else if (prompt.is_array()) {
|
||||||
std::vector<json> prompts = prompt;
|
std::vector<json> prompts = prompt;
|
||||||
for (size_t i = 0; i < prompts.size(); i++) {
|
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
const auto & e = prompts[i];
|
for (size_t i = 1; i < prompts.size(); i++) {
|
||||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
json qd;
|
||||||
data["index"] = i;
|
qd.push_back(prompts[0]);
|
||||||
create_task(data, true, e);
|
qd.push_back(prompts[i]);
|
||||||
} else {
|
data["index"] = i - 1;
|
||||||
throw std::runtime_error(error_msg);
|
create_task(data, true, qd);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < prompts.size(); i++) {
|
||||||
|
const auto & e = prompts[i];
|
||||||
|
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||||
|
data["index"] = i;
|
||||||
|
create_task(data, true, e);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(error_msg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1471,7 +1531,9 @@ struct server_context {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t idx = result.data["index"];
|
const size_t idx = result.data["index"];
|
||||||
|
GGML_ASSERT(idx < results.size() && "index out of range");
|
||||||
|
|
||||||
results[idx] = result;
|
results[idx] = result;
|
||||||
}
|
}
|
||||||
result_handler(results);
|
result_handler(results);
|
||||||
@ -1922,6 +1984,29 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
prompt_tokens = embd_inp;
|
prompt_tokens = embd_inp;
|
||||||
|
} else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
// require slot.prompt to be array of 2 strings
|
||||||
|
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
||||||
|
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
||||||
|
slot.release();
|
||||||
|
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prompt: <s>query</s><s>doc</s>
|
||||||
|
prompt_tokens.clear();
|
||||||
|
prompt_tokens.push_back(llama_token_bos(model));
|
||||||
|
{
|
||||||
|
const auto part = tokenize(slot.prompt[0], false);
|
||||||
|
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||||
|
}
|
||||||
|
prompt_tokens.push_back(llama_token_eos(model));
|
||||||
|
prompt_tokens.push_back(llama_token_bos(model));
|
||||||
|
{
|
||||||
|
const auto part = tokenize(slot.prompt[1], false);
|
||||||
|
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||||
|
}
|
||||||
|
prompt_tokens.push_back(llama_token_eos(model));
|
||||||
} else {
|
} else {
|
||||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||||
}
|
}
|
||||||
@ -1941,7 +2026,7 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
// this prompt is too large to process - discard it
|
// this prompt is too large to process - discard it
|
||||||
if (slot.n_prompt_tokens > n_ubatch) {
|
if (slot.n_prompt_tokens > n_ubatch) {
|
||||||
slot.release();
|
slot.release();
|
||||||
@ -2011,7 +2096,7 @@ struct server_context {
|
|||||||
slot.n_prompt_tokens_processed = 0;
|
slot.n_prompt_tokens_processed = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
@ -2019,7 +2104,10 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
// check that we are in the right batch_type, if not defer the slot
|
||||||
bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
|
const bool slot_type =
|
||||||
|
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||||
|
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||||
|
|
||||||
if (batch_type == -1) {
|
if (batch_type == -1) {
|
||||||
batch_type = slot_type;
|
batch_type = slot_type;
|
||||||
} else if (batch_type != slot_type) {
|
} else if (batch_type != slot_type) {
|
||||||
@ -2192,6 +2280,13 @@ struct server_context {
|
|||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
send_rank(slot, batch_view);
|
||||||
|
slot.release();
|
||||||
|
slot.i_batch = -1;
|
||||||
|
continue; // continue loop of slots
|
||||||
|
}
|
||||||
|
|
||||||
// prompt evaluated for next-token prediction
|
// prompt evaluated for next-token prediction
|
||||||
slot.state = SLOT_STATE_GENERATING;
|
slot.state = SLOT_STATE_GENERATING;
|
||||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
@ -2974,6 +3069,82 @@ int main(int argc, char ** argv) {
|
|||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
//int top_n = 1;
|
||||||
|
//if (body.count("top_n") != 1) {
|
||||||
|
// top_n = body.at("top_n");
|
||||||
|
//} else {
|
||||||
|
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
// return;
|
||||||
|
//}
|
||||||
|
|
||||||
|
json query;
|
||||||
|
if (body.count("query") == 1) {
|
||||||
|
query = body.at("query");
|
||||||
|
if (!query.is_string()) {
|
||||||
|
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exit(0);
|
||||||
|
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
json documents;
|
||||||
|
if (body.count("documents") != 0) {
|
||||||
|
documents = body.at("documents");
|
||||||
|
if (!documents.is_array() || documents.size() == 0) {
|
||||||
|
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct prompt object: array of ["query", "doc0", "doc1", ...]
|
||||||
|
json prompt;
|
||||||
|
prompt.push_back(query);
|
||||||
|
for (const auto & doc : documents) {
|
||||||
|
prompt.push_back(doc);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
|
||||||
|
|
||||||
|
// create and queue the task
|
||||||
|
json responses = json::array();
|
||||||
|
bool error = false;
|
||||||
|
{
|
||||||
|
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
||||||
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
|
// get the result
|
||||||
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
|
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||||
|
for (const auto & res : results) {
|
||||||
|
responses.push_back(res.data);
|
||||||
|
}
|
||||||
|
}, [&](const json & error_data) {
|
||||||
|
res_error(res, error_data);
|
||||||
|
error = true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// write JSON response
|
||||||
|
json root = format_response_rerank(body, responses);
|
||||||
|
res_ok(res, root);
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||||
json result = json::array();
|
json result = json::array();
|
||||||
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
||||||
@ -3070,6 +3241,7 @@ int main(int argc, char ** argv) {
|
|||||||
svr->Post("/embedding", handle_embeddings); // legacy
|
svr->Post("/embedding", handle_embeddings); // legacy
|
||||||
svr->Post("/embeddings", handle_embeddings);
|
svr->Post("/embeddings", handle_embeddings);
|
||||||
svr->Post("/v1/embeddings", handle_embeddings);
|
svr->Post("/v1/embeddings", handle_embeddings);
|
||||||
|
svr->Post("/v1/rerank", handle_rerank);
|
||||||
svr->Post("/tokenize", handle_tokenize);
|
svr->Post("/tokenize", handle_tokenize);
|
||||||
svr->Post("/detokenize", handle_detokenize);
|
svr->Post("/detokenize", handle_detokenize);
|
||||||
// LoRA adapters hotswap
|
// LoRA adapters hotswap
|
||||||
|
@ -534,7 +534,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||||||
json res = json {
|
json res = json {
|
||||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
{"object", "list"},
|
{"object", "list"},
|
||||||
{"usage", json {
|
{"usage", json { // TODO: fill
|
||||||
{"prompt_tokens", 0},
|
{"prompt_tokens", 0},
|
||||||
{"total_tokens", 0}
|
{"total_tokens", 0}
|
||||||
}},
|
}},
|
||||||
@ -544,6 +544,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static json format_response_rerank(const json & request, const json & ranks) {
|
||||||
|
json data = json::array();
|
||||||
|
int i = 0;
|
||||||
|
for (const auto & rank : ranks) {
|
||||||
|
data.push_back(json{
|
||||||
|
{"index", i++},
|
||||||
|
{"relevance_score", json_value(rank, "rank", 0.0)},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
json res = json {
|
||||||
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
|
{"object", "list"},
|
||||||
|
{"usage", json { // TODO: fill
|
||||||
|
{"prompt_tokens", 0},
|
||||||
|
{"total_tokens", 0}
|
||||||
|
}},
|
||||||
|
{"results", data}
|
||||||
|
};
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
static bool is_valid_utf8(const std::string & str) {
|
static bool is_valid_utf8(const std::string & str) {
|
||||||
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
||||||
const unsigned char* end = bytes + str.length();
|
const unsigned char* end = bytes + str.length();
|
||||||
|
Loading…
Reference in New Issue
Block a user