Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 03:44:35 +00:00)
multiple client support
This commit is contained in:
parent 81484805f0
commit 5b8e29de53
@@ -76,7 +76,7 @@ struct slot_params {
uint32_t seed = -1; // RNG seed
int32_t n_predict = 128; // new tokens to predict
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
bool remember_generation = false; // remember a part of the prompt to avoid reprocessing the whole prompt
bool remember_generation = false; // remember the prompt to avoid reprocessing the whole prompt
std::vector<std::string> antiprompt;
};
@@ -256,6 +256,7 @@ struct llama_client_slot
stopping_word = "";
multibyte_pending = 0;
n_past = 0;
sent_count = 0;

if (grammar != nullptr) {
llama_grammar_free(grammar);
@@ -299,8 +300,7 @@ struct llama_client_slot
}

bool available() {
return state == IDLE &&
command == NONE && !params.remember_generation;
return state == IDLE && command == NONE;
}

bool isProcessing() {
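The simplified available() check means a slot can be handed to a new request whenever it is idle with no pending command. A minimal sketch of how the server loop could pick a free slot with this check (the helper name get_available_slot is hypothetical and not part of this commit; the types come from server.cpp):

// Hypothetical helper: return the first reusable slot,
// or nullptr if every slot is still busy serving another client.
static llama_client_slot * get_available_slot(std::vector<llama_client_slot> & slots) {
    for (llama_client_slot & slot : slots) {
        if (slot.available()) { // IDLE and no pending command
            return &slot;
        }
    }
    return nullptr; // caller has to wait or reject the request
}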
@@ -354,12 +354,6 @@ struct llama_server_context
int n_ctx;
int n_vocab;
bool clean_kv_cache = true;
std::mutex mutex;

std::unique_lock<std::mutex> lock()
{
return std::unique_lock<std::mutex>(mutex);
}

~llama_server_context()
{
@@ -406,7 +400,7 @@ struct llama_server_context
slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
slot.reset();
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
LOG_TEE(" - slot %i\n", slot.id);
LOG_TEE(" -> Slot %i\n", slot.id);
slots.push_back(slot);
}
LOG_TEE("Context Size: %i\n", params.n_ctx);
@@ -716,7 +710,6 @@ struct llama_server_context
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.sent_count, slot.generated_text.size());
result.tok = -1;
} else {
is_stop_full = false;
stop_pos = findStoppingStrings(str_test, token_str.size(),
@@ -737,7 +730,6 @@ struct llama_server_context
{
slot.generated_token_probs.push_back(result);
}

if (slot.multibyte_pending > 0)
{
slot.multibyte_pending -= token_str.size();
@@ -780,7 +772,6 @@ struct llama_server_context
slot.stopped_eos = true;
LOG_VERBOSE("eos token found", {});
}

LOG_VERBOSE("next token", {
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
@@ -818,18 +809,11 @@ struct llama_server_context
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
{
LOG_TEE("slot %i released\n", slot.id);
if(!slot.params.remember_generation) {
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system, n_ctx);
slot.state = IDLE;
slot.command = NONE;
slot.generated_text.clear();
return true;
} else {
slot.state = SLEEPING;
slot.command = NONE;
}
slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
slot.command = NONE;
continue;
}

kv_cache_free -= slot.num_prompt_tokens;

if (slot.state == IDLE || slot.command == RELEASE) {
@@ -858,23 +842,28 @@ struct llama_server_context

auto prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
slot.num_prompt_tokens = prompt_tokens.size();

slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;

slot.context_tokens = prompt_tokens;

if (slot.n_past == slot.num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
slot.n_past--;
}

llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);

LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
});

if(system_prompt.empty()) {
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
}

std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
//printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
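When keep_gen is set, slot.n_past is the length of the token prefix the slot already holds in its KV cache, so only the new suffix of the prompt has to be evaluated. A sketch of that prefix computation, assuming a helper along the lines of the common_part() used above (illustrative implementation, not taken from the commit):

// Length of the shared token prefix between the tokens already cached for the
// slot and the freshly tokenized prompt; everything past it is re-evaluated.
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}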
@@ -1693,7 +1682,6 @@ int main(int argc, char **argv)

svr.Post("/completion", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();

json data = json::parse(req.body);
@@ -1763,13 +1751,12 @@ int main(int argc, char **argv)
// "application/json");
} else {
const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while(slot->isProcessing()) {
if(slot->hasNewToken()) { // new token notification
const completion_token_output token = slot->next();
std::vector<completion_token_output> probs_output = {};
if (slot->sparams.n_probs > 0) {
if (slot->sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
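Here probs_pos and probs_stop_pos bracket the part of generated_token_probs that has not been streamed yet. Roughly, following the variables above (illustrative, not the exact handler body):

// Copy only the per-token probabilities produced since the previous chunk,
// then advance the cursor so the next chunk starts where this one stopped.
std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
    probs_output.assign(slot->generated_token_probs.begin() + probs_pos,
                        slot->generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;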
@@ -1816,12 +1803,10 @@ int main(int argc, char **argv)
return true;
};
auto on_complete = [slot, &llama] (bool) {
llama.mutex.unlock();
slot->sent_tokens = 0;
slot->generated_token_probs.clear();
slot->release();
};
lock.release();
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} });
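The handler above follows cpp-httplib's chunked streaming pattern: the provider pushes server-sent-event chunks while the slot is generating, and on_complete runs once the stream ends or the client disconnects, returning the slot to the pool. A condensed, hypothetical sketch of that shape (payload construction and loop throttling elided):

// Hypothetical outline of the streaming path, not the commit's exact code.
const auto chunked_content_provider = [slot](size_t /*offset*/, httplib::DataSink & sink) {
    while (slot->isProcessing()) {
        if (slot->hasNewToken()) {
            const std::string chunk = "data: {\"content\": \"...\"}\n\n"; // JSON payload elided
            if (!sink.write(chunk.c_str(), chunk.size())) {
                return false; // client disconnected, stop streaming
            }
        }
    }
    sink.done();
    return true;
};
const auto on_complete = [slot](bool /*success*/) {
    slot->release(); // make the slot available for the next client
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);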
@@ -1956,7 +1941,6 @@ int main(int argc, char **argv)

svr.Post("/tokenize", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();

const json body = json::parse(req.body);
std::vector<llama_token> tokens;
@@ -1969,7 +1953,6 @@ int main(int argc, char **argv)

svr.Post("/detokenize", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();

const json body = json::parse(req.body);
std::string content;