fix multiple clients

FSSRepo 2023-10-17 17:54:56 -04:00
parent d2b1fac6c7
commit c02c52efb5
3 changed files with 2257 additions and 2245 deletions

File diff suppressed because it is too large


@@ -195,6 +195,7 @@
import { llama } from '/completion.js';
import { SchemaConverter } from '/json-schema-to-grammar.mjs';
let selected_image = false;
var slot_id = -1;
const session = signal({
prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
@@ -222,7 +223,6 @@
mirostat_eta: 0.1, // learning rate
grammar: '',
n_probs: 0, // no completion_probabilities,
slot_id: -1,
image_data: [],
cache_prompt: true
})
@@ -389,7 +389,6 @@
throw new Error("already running");
}
controller.value = new AbortController();
let slot_id = -1;
for await (const chunk of llama(prompt, llamaParams, {controller: controller.value})) {
const data = chunk.data;
@@ -401,7 +400,6 @@
currentMessages.pop();
}
transcriptUpdate([...history, [char, currentMessages]])
params.value = {...params.value, slot_id}
console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data);
} else {
currentMessages.push(data);
@@ -450,6 +448,7 @@
}
await runLlama(prompt, {
...params.value,
slot_id: slot_id,
stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
}, "{{char}}");
}


@@ -125,6 +125,7 @@ enum slot_command
struct slot_params {
bool stream = true;
uint32_t seed = -1; // RNG seed
int n_keep = 0; // number of tokens to keep from the initial prompt
int32_t n_predict = -1; // new tokens to predict
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
bool cache_prompt = false; // remember the prompt to avoid reprocessing it entirely
@@ -452,6 +453,7 @@ struct llama_server_context
gpt_params params;
int n_ctx;
int n_vocab;
int max_ctx_per_slot = -1;
bool clean_kv_cache = true;
~llama_server_context()
@@ -514,16 +516,23 @@ struct llama_server_context
void initialize() {
// create slots
LOG_TEE("Available slots:\n");
all_slots_are_idle = true;
if(max_ctx_per_slot == -1) {
max_ctx_per_slot = n_ctx / params.n_parallel; // split context
}
if(max_ctx_per_slot * params.n_parallel > n_ctx) {
printf("Error: The max context per slot is more greater than model context size");
return;
}
LOG_TEE("Available slots:\n");
for (int i = 0; i < params.n_parallel; i++)
{
llama_client_slot slot;
slot.id = i;
slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
slot.last_n_tokens.resize(max_ctx_per_slot);
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
slot.reset();
LOG_TEE(" -> Slot %i\n", slot.id);
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
slots.push_back(slot);
}
batch = llama_batch_init(n_ctx, 0);
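
The initialize() hunk above splits the model context evenly across slots when no explicit per-slot limit is given. A minimal standalone sketch of that arithmetic (the helper function and its name are illustrative, not part of server.cpp):

#include <cstdio>

// Illustrative only: derive the per-slot context window the same way
// initialize() does when --ctx-per-slot is not passed on the command line.
static int split_ctx_per_slot(int n_ctx, int n_parallel, int max_ctx_per_slot /* -1 = auto */) {
    if (max_ctx_per_slot == -1) {
        max_ctx_per_slot = n_ctx / n_parallel; // split the model context evenly
    }
    if (max_ctx_per_slot * n_parallel > n_ctx) {
        fprintf(stderr, "error: max context per slot exceeds the model context size\n");
        return -1;
    }
    return max_ctx_per_slot;
}

int main() {
    // e.g. a 4096-token model context shared by 4 parallel slots -> 1024 tokens each
    printf("%d\n", split_ctx_per_slot(4096, 4, -1));
    return 0;
}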
@@ -914,18 +923,17 @@ struct llama_server_context
}
// apply context shift to any slot that has filled its per-slot window
if(params.n_parallel == 1) {
llama_client_slot &slot = slots[0];
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
for(llama_client_slot &slot : slots) {
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
{
// Shift context
const int n_left = slot.n_past - params.n_keep - 1;
const int n_left = slot.n_past - slot.params.n_keep - 1;
const int n_discard = n_left / 2;
llama_kv_cache_seq_rm (ctx, slot.id, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
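
The per-slot context shift above keeps the first n_keep tokens and discards half of the remaining window before shifting the rest left. A small sketch of just that arithmetic, with assumed example values and without the actual llama_kv_cache_* calls:

#include <cstdio>

// Illustrative only: the n_left/n_discard arithmetic from the context shift.
int main() {
    const int max_ctx_per_slot = 1024;             // assumed per-slot window
    const int n_keep           = 32;               // tokens pinned at the start of the prompt
    const int n_past           = max_ctx_per_slot; // the slot has filled its window

    const int n_left    = n_past - n_keep - 1;
    const int n_discard = n_left / 2;

    // positions [n_keep + 1, n_keep + n_discard] are removed from the KV cache,
    // positions [n_keep + 1 + n_discard, n_past) are shifted left by n_discard
    printf("discard %d tokens, keep the first %d, shift the remaining %d\n",
           n_discard, n_keep + 1, n_past - (n_keep + 1 + n_discard));
    return 0;
}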
@@ -1022,16 +1030,16 @@ struct llama_server_context
slot.n_past = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
} else {
if (params.n_keep < 0 && params.n_parallel == 1)
if (slot.params.n_keep < 0)
{
params.n_keep = (int)slot.num_prompt_tokens;
slot.params.n_keep = (int)slot.num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
//if input prompt is too big, truncate like normal
if (slot.num_prompt_tokens >= (size_t)n_ctx)
if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
{
const int n_left = n_ctx - params.n_keep;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int n_left = max_ctx_per_slot - slot.params.n_keep;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
// Use half the left-over space in the context for the prompt
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
LOG_VERBOSE("input truncated", {
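
Prompt truncation in this hunk keeps the first n_keep tokens and fills half of the remaining per-slot window with the tail of the prompt. A self-contained sketch of the same logic (token values and sizes are made up for illustration):

#include <cstdio>
#include <vector>

// Illustrative only: truncate an over-long prompt the way the hunk above does,
// keeping the first n_keep tokens plus the last (max_ctx_per_slot - n_keep) / 2.
static std::vector<int> truncate_prompt(const std::vector<int> & prompt_tokens,
                                        int max_ctx_per_slot, int n_keep) {
    if ((int) prompt_tokens.size() < max_ctx_per_slot) {
        return prompt_tokens; // already fits in the slot's window
    }
    const int n_left = max_ctx_per_slot - n_keep;
    std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
    return new_tokens;
}

int main() {
    std::vector<int> prompt(2000, 0); // 2000 dummy tokens
    printf("truncated to %zu tokens\n", truncate_prompt(prompt, 1024, 32).size());
    return 0;
}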
@@ -1331,6 +1339,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "-cps" || arg == "--ctx-per-slot" || arg == "--ctx_per_slot")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
llama.max_ctx_per_slot = std::stoi(argv[i]);
}
else if (arg == "--rope-freq-base")
{
@@ -1717,7 +1734,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", 0);
slot->params.n_keep = json_value(body, "n_keep", slot->params.n_keep);
slot->params.seed = json_value(body, "seed", default_params.seed);
slot->params.grammar = json_value(body, "grammar", default_params.grammar);
slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);