grammar + no stream completion

This commit is contained in:
FSSRepo 2023-10-12 18:43:57 -04:00
parent 5b8e29de53
commit 83c2b3553a

View File

@ -32,7 +32,7 @@ struct server_params
{ {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
int32_t port = 8040; int32_t port = 8080;
int32_t read_timeout = 600; int32_t read_timeout = 600;
int32_t write_timeout = 600; int32_t write_timeout = 600;
}; };
@ -78,6 +78,8 @@ struct slot_params {
std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::string grammar = ""; // optional BNF-like grammar to constrain sampling
bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
std::vector<std::string> antiprompt; std::vector<std::string> antiprompt;
json input_prefix;
json input_suffix;
}; };
// completion token output with probabilities // completion token output with probabilities
@ -233,6 +235,7 @@ struct llama_client_slot
std::string stopping_word; std::string stopping_word;
int32_t multibyte_pending = 0; int32_t multibyte_pending = 0;
size_t sent_count = 0; size_t sent_count = 0;
bool infill = false;
struct slot_params params; struct slot_params params;
struct llama_sampling_params sparams; struct llama_sampling_params sparams;
@ -257,6 +260,7 @@ struct llama_client_slot
multibyte_pending = 0; multibyte_pending = 0;
n_past = 0; n_past = 0;
sent_count = 0; sent_count = 0;
infill = false;
if (grammar != nullptr) { if (grammar != nullptr) {
llama_grammar_free(grammar); llama_grammar_free(grammar);
@ -508,82 +512,6 @@ struct llama_server_context
return true; return true;
} }
void loadInfill()
{
// bool suff_rm_leading_spc = true;
// if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
// params.input_suffix.erase(0, 1);
// suff_rm_leading_spc = false;
// }
// auto prefix_tokens = tokenize(params.input_prefix, false);
// auto suffix_tokens = tokenize(params.input_suffix, false);
// const int space_token = 29871;
// if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
// suffix_tokens.erase(suffix_tokens.begin());
// }
// prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
// prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
// prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
// prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
// prefix_tokens.push_back(llama_token_middle(ctx));
// auto prompt_tokens = prefix_tokens;
// num_prompt_tokens = prompt_tokens.size();
// if (params.n_keep < 0)
// {
// params.n_keep = (int)num_prompt_tokens;
// }
// params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// // if input prompt is too big, truncate like normal
// if (num_prompt_tokens >= (size_t)params.n_ctx)
// {
// printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
// // todo we probably want to cut from both sides
// const int n_left = (params.n_ctx - params.n_keep) / 2;
// std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
// const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
// new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
// std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
// LOG_VERBOSE("input truncated", {
// {"n_ctx", params.n_ctx},
// {"n_keep", params.n_keep},
// {"n_left", n_left},
// {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
// });
// truncated = true;
// prompt_tokens = new_tokens;
// }
// else
// {
// const size_t ps = num_prompt_tokens;
// std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// }
// // compare the evaluated prompt with the new prompt
// n_past = common_part(embd, prompt_tokens);
// embd = prompt_tokens;
// if (n_past == num_prompt_tokens)
// {
// // we have to evaluate at least 1 token to generate logits.
// printf("we have to evaluate at least 1 token to generate logits\n");
// n_past--;
// }
// LOG_VERBOSE("prompt ingested", {
// {"n_past", n_past},
// {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
// {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
// });
// has_next_token = true;
}
void cleanKVCache() { void cleanKVCache() {
// clear the entire KV cache // clear the entire KV cache
for (int i = 0; i < params.n_parallel; ++i) for (int i = 0; i < params.n_parallel; ++i)
@ -839,8 +767,29 @@ struct llama_server_context
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) { if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING; slot.state = PROCESSING;
slot.command = NONE; slot.command = NONE;
std::vector<llama_token> prompt_tokens;
if(slot.infill) {
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871;
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
prompt_tokens = prefix_tokens;
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}
auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
slot.num_prompt_tokens = prompt_tokens.size(); slot.num_prompt_tokens = prompt_tokens.size();
slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0; slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
@ -1304,7 +1253,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
} }
static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot) static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
{ {
const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx)); const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != slot->sparams.logit_bias.end() && const bool ignore_eos = eos_bias != slot->sparams.logit_bias.end() &&
@ -1428,7 +1377,7 @@ static T json_value(const json &body, const std::string &key, const T &default_v
: default_value; : default_value;
} }
static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama) static void parse_options_completion(const json &body, llama_client_slot* slot, llama_server_context &llama)
{ {
slot_params default_params; slot_params default_params;
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
@ -1508,26 +1457,26 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
} }
// static void parse_options_infill(const json &body, llama_server_context &llama) static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)
// { {
// if (body.count("input_prefix") != 0) if (body.count("input_prefix") != 0)
// { {
// llama.params.input_prefix = body["input_prefix"]; slot->params.input_prefix = body["input_prefix"];
// } }
// else else
// { {
// llama.params.input_prefix = ""; slot->params.input_prefix = "";
// } }
// if (body.count("input_suffix") != 0) if (body.count("input_suffix") != 0)
// { {
// llama.params.input_suffix = body["input_suffix"]; slot->params.input_suffix = body["input_suffix"];
// } }
// else else
// { {
// llama.params.input_suffix = ""; slot->params.input_suffix = "";
// } }
// parse_options_completion(body, slot, llama); parse_options_completion(body, slot, llama);
// } }
static void log_server_request(const Request &req, const Response &res) static void log_server_request(const Request &req, const Response &res)
{ {
@ -1682,7 +1631,6 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const Request &req, Response &res) svr.Post("/completion", [&llama](const Request &req, Response &res)
{ {
json data = json::parse(req.body); json data = json::parse(req.body);
llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1)); llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1));
@ -1702,7 +1650,7 @@ int main(int argc, char **argv)
slot->reset(); slot->reset();
parse_options_completion(json::parse(req.body), slot, llama); parse_options_completion(data, slot, llama);
if (!llama.launchSlot(slot)) if (!llama.launchSlot(slot))
{ {
@ -1711,44 +1659,36 @@ int main(int argc, char **argv)
} }
if (!slot->params.stream) { if (!slot->params.stream) {
// if (llama.params.n_beams) { std::string completion_text = "";
// // Fill llama.generated_token_probs vector with final beam. if (llama.params.n_beams) {
// llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, // // Fill llama.generated_token_probs vector with final beam.
// llama.n_past, llama.n_remain); // llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
// // Translate llama.generated_token_probs to llama.generated_text. // slot->n_past, llama.n_remain);
// append_to_generated_text_from_generated_token_probs(llama); // // Translate llama.generated_token_probs to llama.generated_text.
// } else { // append_to_generated_text_from_generated_token_probs(llama);
// size_t stop_pos = std::string::npos; } else {
// while (llama.has_next_token) { while (slot->isProcessing()) {
// const completion_token_output token_with_probs = llama.doCompletion(); if(slot->hasNewToken()) {
// const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok); completion_text += slot->next().text_to_send;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
}
}
// stop_pos = llama.findStoppingStrings(llama.generated_text, auto probs = slot->generated_token_probs;
// token_text.size(), STOP_FULL); if (slot->sparams.n_probs > 0 && slot->stopped_word) {
// } const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
}
// if (stop_pos == std::string::npos) { const json data = format_final_response(llama, slot, completion_text, probs);
// stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
// }
// if (stop_pos != std::string::npos) {
// llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
// llama.generated_text.end());
// }
// }
// auto probs = llama.generated_token_probs; //llama_print_timings(llama.ctx);
// if (llama.params.n_probs > 0 && llama.stopped_word) {
// const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
// probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
// }
// const json data = format_final_response(llama, llama.generated_text, probs); res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
// llama_print_timings(llama.ctx);
// res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
// "application/json");
} else { } else {
const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) { const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_token_probs_index = 0; size_t sent_token_probs_index = 0;
@ -1810,131 +1750,101 @@ int main(int argc, char **argv)
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} }); } });
// svr.Post("/infill", [&llama](const Request &req, Response &res) svr.Post("/infill", [&llama](const Request &req, Response &res)
// { {
// auto lock = llama.lock();
// llama.rewind(); json data = json::parse(req.body);
// llama_reset_timings(llama.ctx); llama_client_slot* slot = llama.getSlot(json_value(data, "slot_id", -1));
// parse_options_infill(json::parse(req.body), llama); if(slot == nullptr) {
LOG_TEE("slot unavailable\n");
res.status = 404;
res.set_content("slot_error", "text/plain");
return;
}
// if (!llama.loadGrammar()) if(data.contains("system_prompt")) {
// { llama.processSystemPromptData(data["system_prompt"]);
// res.status = 400; }
// return;
// }
// llama.loadInfill();
// llama.beginCompletion();
// const auto chunked_content_provider = [&](size_t, DataSink & sink) {
// size_t sent_count = 0;
// size_t sent_token_probs_index = 0;
// while (llama.has_next_token) { // llama_reset_timings(llama.ctx);
// const completion_token_output token_with_probs = llama.doCompletion();
// if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
// continue;
// }
// const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
// size_t pos = std::min(sent_count, llama.generated_text.size()); slot->reset();
slot->infill = true;
// const std::string str_test = llama.generated_text.substr(pos); parse_options_infill(data, llama, slot);
// bool is_stop_full = false;
// size_t stop_pos =
// llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
// if (stop_pos != std::string::npos) {
// is_stop_full = true;
// llama.generated_text.erase(
// llama.generated_text.begin() + pos + stop_pos,
// llama.generated_text.end());
// pos = std::min(sent_count, llama.generated_text.size());
// } else {
// is_stop_full = false;
// stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
// STOP_PARTIAL);
// }
// if ( if (!llama.launchSlot(slot))
// stop_pos == std::string::npos || {
// // Send rest of the text if we are at the end of the generation res.status = 400;
// (!llama.has_next_token && !is_stop_full && stop_pos > 0) return;
// ) { }
// const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
// sent_count += to_send.size(); const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_token_probs_index = 0;
while(slot->isProcessing()) {
if(slot->hasNewToken()) { // new token notification
const completion_token_output token = slot->next();
std::vector<completion_token_output> probs_output = {};
if (slot->sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if(!sink.write(str.c_str(), str.size())) {
slot->release();
return false;
}
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
}
const json data = format_final_response(
llama, slot,
"",
std::vector<completion_token_output>(
slot->generated_token_probs.begin(),
slot->generated_token_probs.begin() + sent_token_probs_index)
);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
slot->release();
return false;
}
sink.done();
return true;
};
auto on_complete = [slot, &llama] (bool) {
slot->sent_tokens = 0;
slot->generated_token_probs.clear();
slot->release();
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
});
// std::vector<completion_token_output> probs_output = {}; svr.Get("/model.json", [&llama](const Request &, Response &res)
{
// if (llama.params.n_probs > 0) { const json data = format_generation_settings(llama, llama.getSlot(0));
// const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); return res.set_content(data.dump(), "application/json"); });
// size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
// size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
// if (probs_pos < probs_stop_pos) {
// probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
// }
// sent_token_probs_index = probs_stop_pos;
// }
// const json data = format_partial_response(llama, to_send, probs_output);
// const std::string str =
// "data: " +
// data.dump(-1, ' ', false, json::error_handler_t::replace) +
// "\n\n";
// LOG_VERBOSE("data stream", {
// { "to_send", str }
// });
// if (!sink.write(str.data(), str.size())) {
// LOG_VERBOSE("stream closed", {});
// llama_print_timings(llama.ctx);
// return false;
// }
// }
// if (!llama.has_next_token) {
// // Generation is done, send extra information.
// const json data = format_final_response(
// llama,
// "",
// std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
// );
// const std::string str =
// "data: " +
// data.dump(-1, ' ', false, json::error_handler_t::replace) +
// "\n\n";
// LOG_VERBOSE("data stream", {
// { "to_send", str }
// });
// if (!sink.write(str.data(), str.size())) {
// LOG_VERBOSE("stream closed", {});
// llama_print_timings(llama.ctx);
// return false;
// }
// }
// }
// llama_print_timings(llama.ctx);
// sink.done();
// return true;
// };
// const auto on_complete = [&](bool) {
// llama.mutex.unlock();
// };
// lock.release();
// res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
// });
// svr.Get("/model.json", [&llama](const Request &, Response &res)
// {
// const json data = format_generation_settings(llama);
// return res.set_content(data.dump(), "application/json"); });
svr.Options(R"(/.*)", [](const Request &, Response &res) svr.Options(R"(/.*)", [](const Request &, Response &res)
{ return res.set_content("", "application/json"); }); { return res.set_content("", "application/json"); });
@ -1965,29 +1875,29 @@ int main(int argc, char **argv)
const json data = format_detokenized_response(content); const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json"); }); return res.set_content(data.dump(), "application/json"); });
// svr.Post("/embedding", [&llama](const Request &req, Response &res) svr.Post("/embedding", [&llama](const Request &req, Response &res)
// { {
// auto lock = llama.lock(); const json body = json::parse(req.body);
// const json body = json::parse(req.body); llama_client_slot* slot = llama.getSlot(-1);
// llama.rewind(); slot->reset();
// llama_reset_timings(llama.ctx); //llama_reset_timings(llama.ctx);
// if (body.count("content") != 0) if (body.count("content") != 0)
// { {
// llama.prompt = body["content"]; slot->prompt = body["content"];
// } }
// else else
// { {
// llama.prompt = ""; slot->prompt = "";
// } }
// llama.params.n_predict = 0; llama.params.n_predict = 0;
// llama.loadPrompt(); llama.launchSlot(slot);
// llama.beginCompletion(); while(slot->isProcessing()) {
// llama.doCompletion(); std::this_thread::sleep_for(std::chrono::microseconds(10));
}
// const json data = format_embedding_response(llama); const json data = format_embedding_response(llama);
// return res.set_content(data.dump(), "application/json"); }); return res.set_content(data.dump(), "application/json"); });
svr.set_logger(log_server_request); svr.set_logger(log_server_request);