mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
server : ensure batches are either all embed or all completion (#8420)
* make sure batches are all embed or all non-embed * non-embedding batch for sampled tokens; fix unused params warning
This commit is contained in:
parent
8a4441ea1a
commit
c3ebcfa148
@ -2005,6 +2005,11 @@ struct server_context {
|
|||||||
int32_t n_batch = llama_n_batch(ctx);
|
int32_t n_batch = llama_n_batch(ctx);
|
||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
|
// track if this is an embedding or non-embedding batch
|
||||||
|
// if we've added sampled tokens above, we are in non-embedding mode
|
||||||
|
// -1: none, 0: non-embedding, 1: embedding
|
||||||
|
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params.cont_batching || batch.n_tokens == 0) {
|
if (params.cont_batching || batch.n_tokens == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
@ -2175,6 +2180,14 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check that we are in the right batch_type, if not defer the slot
|
||||||
|
bool slot_type = slot.embedding ? 1 : 0;
|
||||||
|
if (batch_type == -1) {
|
||||||
|
batch_type = slot_type;
|
||||||
|
} else if (batch_type != slot_type) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// keep only the common part
|
// keep only the common part
|
||||||
int p0 = (int) system_tokens.size() + slot.n_past;
|
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
|
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
|
||||||
@ -2276,6 +2289,9 @@ struct server_context {
|
|||||||
{"n_tokens", batch.n_tokens},
|
{"n_tokens", batch.n_tokens},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// make sure we're in the right embedding mode
|
||||||
|
llama_set_embeddings(ctx, batch_type == 1);
|
||||||
|
|
||||||
// process the created batch of tokens
|
// process the created batch of tokens
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||||
@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
if (ctx_server.params.embedding) {
|
||||||
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
if (ctx_server.params.embedding) {
|
||||||
|
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||||
|
|
||||||
@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
if (ctx_server.params.embedding) {
|
||||||
|
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
|
|||||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings = [¶ms, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!params.embedding) {
|
|
||||||
res.status = 501;
|
|
||||||
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
bool is_openai = false;
|
bool is_openai = false;
|
||||||
|
Loading…
Reference in New Issue
Block a user