Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-08 17:51:45 +00:00)
server : remove legacy system_prompt feature

ggml-ci

commit f6dd38c2dd (parent 11ac9800af)
@@ -1788,23 +1788,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_threads_http = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
-    add_opt(common_arg(
-        {"-spf", "--system-prompt-file"}, "FNAME",
-        "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications",
-        [](common_params & params, const std::string & value) {
-            std::ifstream file(value);
-            if (!file) {
-                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
-            }
-            std::string system_prompt;
-            std::copy(
-                std::istreambuf_iterator<char>(file),
-                std::istreambuf_iterator<char>(),
-                std::back_inserter(system_prompt)
-            );
-            params.system_prompt = system_prompt;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--metrics"},
         string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
@@ -282,7 +282,6 @@ struct common_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
-    std::string system_prompt = ""; // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
@@ -623,12 +623,6 @@ struct server_context {
 
     int32_t n_ctx; // total context for all clients / slots
 
-    // system prompt
-    bool system_need_update = false;
-
-    std::string system_prompt;
-    std::vector<llama_token> system_tokens;
-
     // slots / clients
     std::vector<server_slot> slots;
     json default_generation_settings_for_props;
@@ -665,7 +659,7 @@ struct server_context {
     bool load_model(const common_params & params_) {
         params = params_;
 
-        // dedicate one sequence to the system prompt
+        // reserve one extra sequence (seq_id == 0) for extra features
        params.n_parallel += 1;
 
         common_init_result llama_init = common_init_from_params(params);
@@ -1061,51 +1055,6 @@ struct server_context {
         clean_kv_cache = false;
     }
 
-    void system_prompt_update() {
-        SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
-
-        kv_cache_clear();
-        system_tokens.clear();
-
-        if (!system_prompt.empty()) {
-            system_tokens = common_tokenize(ctx, system_prompt, true);
-
-            const int32_t n_batch = llama_n_batch(ctx);
-            const int32_t n_tokens_prompt = system_tokens.size();
-
-            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
-                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
-
-                common_batch_clear(batch);
-
-                for (int32_t j = 0; j < n_tokens; ++j) {
-                    common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
-                }
-
-                if (llama_decode(ctx, batch) != 0) {
-                    SRV_ERR("%s", "llama_decode() failed\n");
-                    return;
-                }
-            }
-
-            // assign the system KV cache to all parallel sequences
-            for (int32_t i = 1; i <= params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
-            }
-        }
-
-        system_need_update = false;
-    }
-
-    bool system_prompt_set(const std::string & sys_prompt) {
-        SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
-
-        system_prompt = sys_prompt;
-        // update system_tokens and KV cache as soon as all slots are idle
-        system_need_update = true;
-        return true;
-    }
-
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
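Note: the removed system_prompt_update() above is the whole of the old shared-system-prompt mechanism: tokenize the prompt, decode it once into sequence 0, then copy the cached cells to every slot sequence. Below is a rough standalone sketch of that technique, using only the llama.cpp calls visible in the hunk; the seed_shared_prefix wrapper and its parameters are illustrative, not part of the server.

// Standalone sketch (assumes an initialized llama_context from llama.cpp):
// decode a shared prefix once into sequence 0, then share its KV cells with
// sequences 1..n_slots so no slot has to recompute it.
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <string>
#include <vector>

static bool seed_shared_prefix(llama_context * ctx, const std::string & text, int32_t n_slots) {
    const std::vector<llama_token> tokens = common_tokenize(ctx, text, true);

    const int32_t n_batch = llama_n_batch(ctx);
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    // decode the prefix into sequence 0 in chunks of n_batch
    for (int32_t i = 0; i < (int32_t) tokens.size(); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, (int32_t) tokens.size() - i);

        common_batch_clear(batch);
        for (int32_t j = 0; j < n_tokens; ++j) {
            common_batch_add(batch, tokens[i + j], i + j, { 0 }, false);
        }

        if (llama_decode(ctx, batch) != 0) {
            llama_batch_free(batch);
            return false;
        }
    }

    // copy the cached prefix to every slot sequence (no recomputation)
    for (int32_t s = 1; s <= n_slots; ++s) {
        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
    }

    llama_batch_free(batch);
    return true;
}

Each slot then had to start its own tokens at position tokens.size(), which is exactly the system_tokens.size() offset that the remaining hunks delete.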
@@ -1855,12 +1804,8 @@ struct server_context {
         }
 
         if (all_idle) {
-            if (system_need_update) {
-                system_prompt_update();
-            }
-
             SRV_INF("%s", "all slots are idle\n");
-            if (system_prompt.empty() && clean_kv_cache) {
+            if (clean_kv_cache) {
                 kv_cache_clear();
             }
 
@@ -1882,7 +1827,7 @@ struct server_context {
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
-                if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
                     if (!params.ctx_shift) {
                         // this check is redundant (for good)
                         // we should never get here, because generation should already stopped in process_token()
@@ -1893,13 +1838,13 @@ struct server_context {
 
                     // Shift context
                     const int n_keep = slot.params.n_keep + add_bos_token;
-                    const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
+                    const int n_left = slot.n_past - n_keep;
                     const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
                     SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
                     llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep, n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
 
                     if (slot.params.cache_prompt) {
                         for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
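Note: with the system_tokens offset gone, the context-shift bookkeeping above operates purely on slot.n_past. A condensed sketch with a worked example follows; the numbers and the helper name are illustrative, not part of the server.

// Context shift for one slot sequence, as in the hunk above.
// Example with illustrative values: n_ctx = 4096, n_past = 4095, n_keep = 1 (BOS),
// slot.params.n_discard = 0, so n_left = 4095 - 1 = 4094 and n_discard = 4094 / 2 = 2047.
#include "llama.h"

static void shift_slot_context(llama_context * ctx, llama_seq_id seq, int n_past, int n_keep, int n_discard) {
    // drop cache cells [n_keep, n_keep + n_discard) ...
    llama_kv_cache_seq_rm (ctx, seq, n_keep, n_keep + n_discard);
    // ... and slide the remaining cells [n_keep + n_discard, n_past) back by n_discard,
    // so generation continues at position n_past - n_discard
    llama_kv_cache_seq_add(ctx, seq, n_keep + n_discard, n_past, -n_discard);
}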
@@ -1929,9 +1874,7 @@ struct server_context {
 
             const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
 
-            // TODO: we always have to take into account the "system_tokens"
-            //       this is not great and needs to be improved somehow
-            common_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
@@ -1939,8 +1882,8 @@ struct server_context {
                 slot.cache_tokens.push_back(slot.sampled);
             }
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
         }
 
         // process in chunks of params.n_batch
@@ -1971,7 +1914,7 @@ struct server_context {
                     case SERVER_TASK_CMPL_TYPE_NORMAL:
                     case SERVER_TASK_CMPL_TYPE_EMBEDDING:
                         {
-                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                            prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
                         } break;
                     case SERVER_TASK_CMPL_TYPE_RERANK:
                         {
@@ -2050,7 +1993,7 @@ struct server_context {
                 } else {
                     if (!params.ctx_shift) {
                         // if context shift is disabled, we make sure prompt size is smaller than KV size
-                        if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                        if (slot.n_prompt_tokens >= slot.n_ctx) {
                             slot.release();
                             send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                             continue;
@@ -2138,22 +2081,16 @@ struct server_context {
                 }
 
                 // keep only the common part
-                int p0 = (int) system_tokens.size() + slot.n_past;
+                int p0 = slot.n_past;
                 if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
                     // could not partially delete (likely using a non-Transformer model)
                     llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                    p0 = (int) system_tokens.size();
-                    if (p0 != 0) {
-                        // copy over the system prompt when there is one
-                        llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
-                    }
-
-                    // there is no common part left (except for the system prompt)
+                    // there is no common part left
                     slot.n_past = 0;
                     slot.n_past_se = 0;
                     slot.ga_i = 0;
-                    // TODO: is the system prompt ever in the sampling context?
+
                     common_sampler_reset(slot.smpl);
                 }
 
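Note: after this change the "common part" is simply the prefix shared with the previous request for that slot; on the fallback path there is no dedicated system sequence to copy back from anymore. A minimal sketch of the simplified reuse logic; the keep_common_prefix helper is illustrative, not part of the server.

// Keep the KV cells for the first n_past tokens (the prefix the new request has in
// common with the cached one) and delete everything after them. If the backend cannot
// remove a partial range (e.g. a non-Transformer model), clear the whole sequence and
// report that nothing can be reused.
#include "llama.h"

static int keep_common_prefix(llama_context * ctx, llama_seq_id seq, int n_past) {
    if (!llama_kv_cache_seq_rm(ctx, seq, n_past, -1)) {
        llama_kv_cache_seq_rm(ctx, seq, -1, -1); // full clear
        return 0;                                // reprocess the prompt from position 0
    }
    return n_past;                               // n_past cached tokens are reused
}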
@@ -2179,7 +2116,7 @@ struct server_context {
                     }
                 }
 
-                common_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
+                common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
 
                 if (slot.params.cache_prompt) {
                     slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -2409,10 +2346,6 @@ int main(int argc, char ** argv) {
     // struct that contains llama context and inference
     server_context ctx_server;
 
-    if (!params.system_prompt.empty()) {
-        ctx_server.system_prompt_set(params.system_prompt);
-    }
-
     if (params.model_alias == "unknown") {
         params.model_alias = params.model;
     }
@@ -2840,7 +2773,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         json data = {
-            { "system_prompt", ctx_server.system_prompt },
+            { "system_prompt", "[unavailable]" },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params.n_parallel },
             { "chat_template", llama_get_chat_template(ctx_server.model) },
@@ -2856,10 +2789,8 @@ int main(int argc, char ** argv) {
         }
 
         json data = json::parse(req.body);
-        if (data.contains("system_prompt")) {
-            std::string system_prompt = data.at("system_prompt");
-            ctx_server.system_prompt_set(system_prompt);
-        }
+
+        // update any props here
 
         res_ok(res, {{ "success", true }});
     };
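Note: with the server-side system prompt removed, a client supplies it per request, for example as a "system" message to the chat completion endpoint (e.g. /v1/chat/completions), where it is rendered through the model's chat template. A minimal sketch of building such a request body with nlohmann::json, which the server itself uses; the helper name is illustrative and the HTTP transport is omitted.

#include <nlohmann/json.hpp>
#include <string>

// Build an OpenAI-style chat request body that carries the system prompt inline.
static std::string make_chat_request_body(const std::string & system_prompt, const std::string & user_msg) {
    nlohmann::json body = {
        { "messages", {
            { { "role", "system" }, { "content", system_prompt } },
            { { "role", "user"   }, { "content", user_msg      } },
        }},
        { "stream", false },
    };
    return body.dump();
}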