multiple client support

FSSRepo 2023-10-12 17:09:12 -04:00
parent 81484805f0
commit 5b8e29de53


@@ -76,7 +76,7 @@ struct slot_params
    uint32_t seed = -1; // RNG seed
    int32_t n_predict = 128; // new tokens to predict
    std::string grammar = ""; // optional BNF-like grammar to constrain sampling
-   bool remember_generation = false; // remember a part of the prompt to avoid reprocessing all prompt
+   bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
    std::vector<std::string> antiprompt;
};
@@ -256,6 +256,7 @@ struct llama_client_slot
        stopping_word = "";
        multibyte_pending = 0;
        n_past = 0;
+       sent_count = 0;

        if (grammar != nullptr) {
            llama_grammar_free(grammar);
@ -299,8 +300,7 @@ struct llama_client_slot
} }
bool available() { bool available() {
return state == IDLE && return state == IDLE && command == NONE;
command == NONE && !params.remember_generation;
} }
bool isProcessing() { bool isProcessing() {
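For orientation while reading the remaining hunks: the slot lifecycle now distinguishes IDLE, SLEEPING, and PROCESSING states plus a pending command (only NONE and RELEASE are visible in this diff). The sketch below is reconstructed from the hunks rather than copied from the file, and the pick_idle_slot helper is hypothetical; it only illustrates why available() can drop the remember_generation check once remembered slots park in SLEEPING instead of IDLE.

#include <vector>

enum slot_state   { IDLE, SLEEPING, PROCESSING }; // state names as they appear in the hunks
enum slot_command { NONE, RELEASE };              // only the commands visible in this diff

struct client_slot_sketch {
    int          id      = -1;
    slot_state   state   = IDLE;
    slot_command command = NONE;

    // after this commit a slot is reusable iff it is idle with nothing pending;
    // slots that keep a cached generation sit in SLEEPING and are skipped here
    bool available() const { return state == IDLE && command == NONE; }
};

// hypothetical helper: give an incoming request the first free slot, or nullptr if all are busy
static client_slot_sketch * pick_idle_slot(std::vector<client_slot_sketch> & slots) {
    for (auto & slot : slots) {
        if (slot.available()) return &slot;
    }
    return nullptr;
}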
@@ -354,12 +354,6 @@ struct llama_server_context
    int n_ctx;
    int n_vocab;
    bool clean_kv_cache = true;
-   std::mutex mutex;
-
-   std::unique_lock<std::mutex> lock()
-   {
-       return std::unique_lock<std::mutex>(mutex);
-   }

    ~llama_server_context()
    {
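The lines removed here are the server-wide lock that serialized every HTTP handler; the same removal shows up later in the /completion, /tokenize, and /detokenize handlers. For reference, the old pattern amounted to the sketch below (handler bodies elided); with per-slot state and per-slot KV cache sequences, requests no longer have to queue behind a single mutex.

#include <mutex>

struct server_context_sketch {
    std::mutex mutex;

    // the removed helper: every handler did `auto lock = llama.lock();`
    // before touching the shared context, so only one client ran at a time
    std::unique_lock<std::mutex> lock() {
        return std::unique_lock<std::mutex>(mutex);
    }
};

Streaming responses also had to hand that lock around by hand (see the lock.release() / llama.mutex.unlock() pair removed further down), which is exactly the bookkeeping the slot model does away with.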
@@ -406,7 +400,7 @@ struct llama_server_context
            slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
            slot.reset();
            std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
-           LOG_TEE(" - slot %i\n", slot.id);
+           LOG_TEE(" -> Slot %i\n", slot.id);
            slots.push_back(slot);
        }
        LOG_TEE("Context Size: %i\n", params.n_ctx);
@@ -716,7 +710,6 @@ struct llama_server_context
                    slot.generated_text.begin() + pos + stop_pos,
                    slot.generated_text.end());
                pos = std::min(slot.sent_count, slot.generated_text.size());
-               result.tok = -1;
            } else {
                is_stop_full = false;
                stop_pos = findStoppingStrings(str_test, token_str.size(),
@ -737,7 +730,6 @@ struct llama_server_context
{ {
slot.generated_token_probs.push_back(result); slot.generated_token_probs.push_back(result);
} }
if (slot.multibyte_pending > 0) if (slot.multibyte_pending > 0)
{ {
slot.multibyte_pending -= token_str.size(); slot.multibyte_pending -= token_str.size();
@ -780,7 +772,6 @@ struct llama_server_context
slot.stopped_eos = true; slot.stopped_eos = true;
LOG_VERBOSE("eos token found", {}); LOG_VERBOSE("eos token found", {});
} }
LOG_VERBOSE("next token", { LOG_VERBOSE("next token", {
{"token", result.tok}, {"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
@@ -818,18 +809,11 @@ struct llama_server_context
            if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
            {
                LOG_TEE("slot %i released\n", slot.id);
-               if(!slot.params.remember_generation) {
-                   llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system, n_ctx);
-                   slot.state = IDLE;
-                   slot.command = NONE;
-                   slot.generated_text.clear();
-                   return true;
-               } else {
-                   slot.state = SLEEPING;
-                   slot.command = NONE;
-               }
+               slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
+               slot.command = NONE;
                continue;
            }

            kv_cache_free -= slot.num_prompt_tokens;
            if (slot.state == IDLE || slot.command == RELEASE) {
@@ -858,23 +842,28 @@ struct llama_server_context
                auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                slot.num_prompt_tokens = prompt_tokens.size();
                slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
                slot.context_tokens = prompt_tokens;

+               if (slot.n_past == slot.num_prompt_tokens) {
+                   // we have to evaluate at least 1 token to generate logits.
+                   printf("we have to evaluate at least 1 token to generate logits\n");
+                   slot.n_past--;
+               }
+
+               llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
+
                LOG_VERBOSE("prompt ingested", {
                    {"n_past", slot.n_past},
                    {"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
                    {"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
                });

-               if(system_prompt.empty()) {
-                   llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
-               }
-
                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);

                for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
-                   printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
+                   //printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
                    batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
                    batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
                    batch.seq_id[batch.n_tokens] = slot.id;
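Prompt ingestion now reuses whatever prefix of the slot's previous context still matches the incoming prompt and clears only the tail of that slot's KV sequence. common_part is not shown in this diff; a plausible implementation, plus a commented restatement of what the added llama_kv_cache_seq_rm call and the n_past-- guard are doing, is sketched below (the llama_token alias is a stand-in for the real typedef).

#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = std::int32_t; // stand-in for llama.h's token id type

// length of the shared prefix between the slot's cached tokens and the new prompt,
// matching how the hunk calls common_part(slot.context_tokens, prompt_tokens)
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// Intent of the added cache call (restated from the hunk, not a new API):
//   llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
// removes this slot's sequence from position num_tokens_system + n_past to the end,
// so the matching prefix stays cached and only the divergent tail is re-evaluated.
// If the whole prompt matched (n_past == num_prompt_tokens), n_past is stepped back
// by one first, because at least one token must be evaluated to produce fresh logits.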
@@ -1693,7 +1682,6 @@ int main(int argc, char **argv)
    svr.Post("/completion", [&llama](const Request &req, Response &res)
            {
-               auto lock = llama.lock();

                json data = json::parse(req.body);
@ -1763,13 +1751,12 @@ int main(int argc, char **argv)
// "application/json"); // "application/json");
} else { } else {
const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) { const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0; size_t sent_token_probs_index = 0;
while(slot->isProcessing()) { while(slot->isProcessing()) {
if(slot->hasNewToken()) { // new token notification if(slot->hasNewToken()) { // new token notification
const completion_token_output token = slot->next(); const completion_token_output token = slot->next();
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
if (slot->sparams.n_probs > 0) { if (slot->sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size()); size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
@@ -1816,12 +1803,10 @@ int main(int argc, char **argv)
                return true;
            };
            auto on_complete = [slot, &llama] (bool) {
-               llama.mutex.unlock();
                slot->sent_tokens = 0;
                slot->generated_token_probs.clear();
                slot->release();
            };
-           lock.release();
            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
        } });
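The chunked handler above polls the slot from the HTTP thread while the decode loop keeps producing tokens, relying only on isProcessing(), hasNewToken(), and next(). How the slot synchronizes that handoff is not part of this diff; the sketch below is a hypothetical illustration of that contract using a mutex-guarded queue, not the branch's actual implementation.

#include <condition_variable>
#include <deque>
#include <mutex>
#include <string>
#include <utility>

struct token_output_sketch {
    int         tok = -1;
    std::string text_to_send;
};

// hypothetical per-slot mailbox: the decode loop push()es finished tokens,
// the streaming handler polls hasNewToken()/next() until release() is called
struct token_stream_sketch {
    std::mutex                      mtx;
    std::condition_variable         cv;
    std::deque<token_output_sketch> pending;
    bool                            processing = true;

    void push(token_output_sketch t) {                     // decode-loop side
        { std::lock_guard<std::mutex> lk(mtx); pending.push_back(std::move(t)); }
        cv.notify_one();
    }
    bool isProcessing() { std::lock_guard<std::mutex> lk(mtx); return processing || !pending.empty(); }
    bool hasNewToken()  { std::lock_guard<std::mutex> lk(mtx); return !pending.empty(); }

    token_output_sketch next() {                           // HTTP-streaming side, blocks until data or release
        std::unique_lock<std::mutex> lk(mtx);
        cv.wait(lk, [&] { return !pending.empty() || !processing; });
        token_output_sketch t;
        if (!pending.empty()) { t = std::move(pending.front()); pending.pop_front(); }
        return t;
    }
    void release() {                                       // decode loop signals end of generation
        { std::lock_guard<std::mutex> lk(mtx); processing = false; }
        cv.notify_all();
    }
};

In the hunk, on_complete additionally resets sent_tokens and clears the stored token probabilities before releasing the slot, so a later request that reuses the slot starts streaming from a clean offset.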
@@ -1956,7 +1941,6 @@ int main(int argc, char **argv)
    svr.Post("/tokenize", [&llama](const Request &req, Response &res)
            {
-               auto lock = llama.lock();
                const json body = json::parse(req.body);
                std::vector<llama_token> tokens;
@@ -1969,7 +1953,6 @@ int main(int argc, char **argv)
    svr.Post("/detokenize", [&llama](const Request &req, Response &res)
            {
-               auto lock = llama.lock();
                const json body = json::parse(req.body);
                std::string content;