multiple client support

This commit is contained in:
FSSRepo 2023-10-12 17:09:12 -04:00
parent 81484805f0
commit 5b8e29de53

View File

@ -76,7 +76,7 @@ struct slot_params {
uint32_t seed = -1; // RNG seed
int32_t n_predict = 128; // new tokens to predict
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
bool remember_generation = false; // remember a part of the prompt to avoid reprocessing all prompt
bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
std::vector<std::string> antiprompt;
};
@ -256,6 +256,7 @@ struct llama_client_slot
stopping_word = "";
multibyte_pending = 0;
n_past = 0;
sent_count = 0;
if (grammar != nullptr) {
llama_grammar_free(grammar);
@ -299,8 +300,7 @@ struct llama_client_slot
}
bool available() {
return state == IDLE &&
command == NONE && !params.remember_generation;
return state == IDLE && command == NONE;
}
bool isProcessing() {
@ -354,12 +354,6 @@ struct llama_server_context
int n_ctx;
int n_vocab;
bool clean_kv_cache = true;
std::mutex mutex;
std::unique_lock<std::mutex> lock()
{
return std::unique_lock<std::mutex>(mutex);
}
~llama_server_context()
{
@ -406,7 +400,7 @@ struct llama_server_context
slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
slot.reset();
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
LOG_TEE(" - slot %i\n", slot.id);
LOG_TEE(" -> Slot %i\n", slot.id);
slots.push_back(slot);
}
LOG_TEE("Context Size: %i\n", params.n_ctx);
@ -716,7 +710,6 @@ struct llama_server_context
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.sent_count, slot.generated_text.size());
result.tok = -1;
} else {
is_stop_full = false;
stop_pos = findStoppingStrings(str_test, token_str.size(),
@ -737,7 +730,6 @@ struct llama_server_context
{
slot.generated_token_probs.push_back(result);
}
if (slot.multibyte_pending > 0)
{
slot.multibyte_pending -= token_str.size();
@ -780,7 +772,6 @@ struct llama_server_context
slot.stopped_eos = true;
LOG_VERBOSE("eos token found", {});
}
LOG_VERBOSE("next token", {
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
@ -818,18 +809,11 @@ struct llama_server_context
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
{
LOG_TEE("slot %i released\n", slot.id);
if(!slot.params.remember_generation) {
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system, n_ctx);
slot.state = IDLE;
slot.command = NONE;
slot.generated_text.clear();
return true;
} else {
slot.state = SLEEPING;
slot.command = NONE;
}
slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
slot.command = NONE;
continue;
}
kv_cache_free -= slot.num_prompt_tokens;
if (slot.state == IDLE || slot.command == RELEASE) {
@ -858,23 +842,28 @@ struct llama_server_context
auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
slot.num_prompt_tokens = prompt_tokens.size();
slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
slot.context_tokens = prompt_tokens;
if (slot.n_past == slot.num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
slot.n_past--;
}
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
});
if(system_prompt.empty()) {
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
}
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
//printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
@ -1693,7 +1682,6 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();
json data = json::parse(req.body);
@ -1763,13 +1751,12 @@ int main(int argc, char **argv)
// "application/json");
} else {
const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while(slot->isProcessing()) {
if(slot->hasNewToken()) { // new token notification
const completion_token_output token = slot->next();
std::vector<completion_token_output> probs_output = {};
if (slot->sparams.n_probs > 0) {
if (slot->sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
@ -1816,12 +1803,10 @@ int main(int argc, char **argv)
return true;
};
auto on_complete = [slot, &llama] (bool) {
llama.mutex.unlock();
slot->sent_tokens = 0;
slot->generated_token_probs.clear();
slot->release();
};
lock.release();
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} });
@ -1956,7 +1941,6 @@ int main(int argc, char **argv)
svr.Post("/tokenize", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
@ -1969,7 +1953,6 @@ int main(int argc, char **argv)
svr.Post("/detokenize", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();
const json body = json::parse(req.body);
std::string content;