fix multiple clients
This commit is contained in: parent d2b1fac6c7, commit c02c52efb5
File diff suppressed because it is too large
@@ -195,6 +195,7 @@
 import { llama } from '/completion.js';
 import { SchemaConverter } from '/json-schema-to-grammar.mjs';
 let selected_image = false;
+var slot_id = -1;
 
 const session = signal({
   prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
@@ -222,7 +223,6 @@
   mirostat_eta: 0.1, // learning rate
   grammar: '',
   n_probs: 0, // no completion_probabilities,
-  slot_id: -1,
   image_data: [],
   cache_prompt: true
 })
@@ -389,7 +389,6 @@
     throw new Error("already running");
   }
   controller.value = new AbortController();
-  let slot_id = -1;
   for await (const chunk of llama(prompt, llamaParams, {controller: controller.value})) {
     const data = chunk.data;
 
@@ -401,7 +400,6 @@
       currentMessages.pop();
     }
     transcriptUpdate([...history, [char, currentMessages]])
-    params.value = {...params.value, slot_id}
    console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data);
   } else {
     currentMessages.push(data);
@@ -450,6 +448,7 @@
   }
   await runLlama(prompt, {
     ...params.value,
+    slot_id: slot_id,
     stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
   }, "{{char}}");
 }
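The web UI now keeps the slot id returned by the server in a module-level variable and echoes it back on the next request, so a conversation keeps landing on the same server slot and can reuse that slot's cached prompt. Below is a minimal C++ sketch of that routing idea; the `slot` struct and the `pick_slot` helper are hypothetical illustrations, not the server's actual scheduler.

#include <cstdio>
#include <vector>

// Simplified stand-in for llama_client_slot: only what the routing idea needs.
struct slot {
    int  id        = -1;
    bool available = true;   // true when the slot is not processing a request
};

// Hypothetical helper: prefer the slot the client asked for (its previous slot),
// otherwise fall back to any idle slot; return nullptr when everything is busy.
static slot *pick_slot(std::vector<slot> &slots, int requested_id) {
    if (requested_id >= 0 && requested_id < (int) slots.size() && slots[requested_id].available) {
        return &slots[requested_id];   // same slot -> its prompt cache is still valid
    }
    for (slot &s : slots) {
        if (s.available) {
            return &s;                 // any free slot; the prompt will be re-evaluated
        }
    }
    return nullptr;                    // all slots busy: the request has to wait
}

int main() {
    std::vector<slot> slots(4);
    for (int i = 0; i < (int) slots.size(); i++) { slots[i].id = i; }

    // First request: the client sends slot_id = -1 and gets slot 0.
    slot *s = pick_slot(slots, -1);
    std::printf("assigned slot %d\n", s->id);
    s->available = false;

    // A request that echoes slot_id = 0 while that slot is busy falls back to slot 1.
    slot *t = pick_slot(slots, 0);
    std::printf("assigned slot %d\n", t->id);
    return 0;
}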
@@ -125,6 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
+    int n_keep = 0; // number of tokens to keep from the initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember the prompt to avoid reprocessing it entirely
@@ -452,6 +453,7 @@ struct llama_server_context
     gpt_params params;
     int n_ctx;
     int n_vocab;
+    int max_ctx_per_slot = -1;
     bool clean_kv_cache = true;
 
     ~llama_server_context()
@@ -514,16 +516,23 @@ struct llama_server_context
 
     void initialize() {
         // create slots
-        LOG_TEE("Available slots:\n");
         all_slots_are_idle = true;
+        if (max_ctx_per_slot == -1) {
+            max_ctx_per_slot = n_ctx / params.n_parallel; // split context
+        }
+        if (max_ctx_per_slot * params.n_parallel > n_ctx) {
+            printf("Error: max context per slot * number of slots is greater than the model context size\n");
+            return;
+        }
+        LOG_TEE("Available slots:\n");
         for (int i = 0; i < params.n_parallel; i++)
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
+            slot.last_n_tokens.resize(max_ctx_per_slot);
             std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
             slot.reset();
-            LOG_TEE(" -> Slot %i\n", slot.id);
+            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
             slots.push_back(slot);
         }
         batch = llama_batch_init(n_ctx, 0);
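The key change in initialize() is that the model context n_ctx is now divided among the parallel slots, either evenly by default or explicitly via max_ctx_per_slot, and rejected when the requested split cannot fit. A small standalone sketch of that arithmetic and check; the function name and values here are chosen for illustration, not taken verbatim from server.cpp.

#include <cstdio>

// Returns the per-slot context size, or -1 if the requested split does not fit.
// Mirrors the logic added to llama_server_context::initialize(): when no explicit
// value is given, divide the model context evenly across the parallel slots.
static int split_context(int n_ctx, int n_parallel, int requested_per_slot /* -1 = auto */) {
    int per_slot = requested_per_slot;
    if (per_slot == -1) {
        per_slot = n_ctx / n_parallel;     // even split; any remainder stays unused
    }
    if (per_slot * n_parallel > n_ctx) {
        return -1;                         // the slots would overrun the model context
    }
    return per_slot;
}

int main() {
    // A 4096-token model context shared by 4 slots -> 1024 tokens per slot.
    std::printf("auto split:      %d\n", split_context(4096, 4, -1));
    // Explicit 2048 tokens per slot with 4 slots would need 8192 -> rejected.
    std::printf("invalid request: %d\n", split_context(4096, 4, 2048));
    return 0;
}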
@@ -914,18 +923,17 @@ struct llama_server_context
         }
 
         // shift context when a slot has filled its share of the context
-        if (params.n_parallel == 1) {
-            llama_client_slot &slot = slots[0];
-            if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
+        for (llama_client_slot &slot : slots) {
+            if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
             {
                 // Shift context
-                const int n_left = slot.n_past - params.n_keep - 1;
+                const int n_left = slot.n_past - slot.params.n_keep - 1;
                 const int n_discard = n_left / 2;
 
-                llama_kv_cache_seq_rm   (ctx, slot.id, params.n_keep + 1, params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
 
-                for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                 {
                     slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                 }
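Context shifting now runs per slot and uses the slot's own n_keep: once a slot's cache reaches its max context, the first n_keep tokens are preserved, half of the remaining tokens are discarded, and the survivors slide down (the KV cache is shifted in the same way via llama_kv_cache_seq_rm / llama_kv_cache_seq_shift). A self-contained sketch of just the token-buffer side of that arithmetic, with made-up sizes:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    // Pretend the slot's cache already holds max_ctx_per_slot tokens 0..15.
    const int n_keep = 4;                        // tokens that must never be discarded (e.g. system prompt)
    std::vector<int> cache_tokens(16);
    std::iota(cache_tokens.begin(), cache_tokens.end(), 0);

    const int n_past    = (int) cache_tokens.size();
    const int n_left    = n_past - n_keep - 1;   // tokens that are allowed to move
    const int n_discard = n_left / 2;            // drop half of them to make room

    // Compact the buffer: everything after the discarded window slides down by n_discard.
    for (size_t i = n_keep + 1 + n_discard; i < cache_tokens.size(); i++) {
        cache_tokens[i - n_discard] = cache_tokens[i];
    }
    cache_tokens.resize(cache_tokens.size() - n_discard);

    for (int t : cache_tokens) { std::printf("%d ", t); }
    std::printf("\n");   // 0 1 2 3 4 10 11 12 13 14 15  (tokens 5..9 were discarded)
    return 0;
}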
@@ -1022,16 +1030,16 @@ struct llama_server_context
                 slot.n_past = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                if (params.n_keep < 0 && params.n_parallel == 1)
+                if (slot.params.n_keep < 0)
                 {
-                    params.n_keep = (int)slot.num_prompt_tokens;
+                    slot.params.n_keep = (int)slot.num_prompt_tokens;
                 }
-                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+                slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
                 // if input prompt is too big, truncate like normal
-                if (slot.num_prompt_tokens >= (size_t)n_ctx)
+                if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
-                    const int n_left = n_ctx - params.n_keep;
-                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+                    const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
                     new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
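Prompt truncation follows the same per-slot budget: keep the first n_keep prompt tokens, then fill half of the remaining per-slot space with the tail of the prompt. A minimal sketch with illustrative numbers; the token type is reduced to int and the function name is hypothetical, while the real code operates on llama_token inside server.cpp.

#include <cstdio>
#include <numeric>
#include <vector>

// Truncate a prompt so it fits into a slot's context: keep the first n_keep tokens,
// then append the last (max_ctx_per_slot - n_keep) / 2 tokens of the original prompt.
static std::vector<int> truncate_prompt(const std::vector<int> &prompt, int max_ctx_per_slot, int n_keep) {
    if ((int) prompt.size() < max_ctx_per_slot) {
        return prompt;                                 // already fits, nothing to do
    }
    const int n_left = max_ctx_per_slot - n_keep;      // per-slot space left after the kept prefix
    std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.end() - n_left / 2, prompt.end());
    return out;
}

int main() {
    std::vector<int> prompt(100);                      // a 100-token prompt
    std::iota(prompt.begin(), prompt.end(), 0);

    // Slot budget of 64 tokens, keeping the first 8: result is 8 + 28 = 36 tokens,
    // which leaves roughly half of the slot's context free for generation and cache shifts.
    std::vector<int> truncated = truncate_prompt(prompt, 64, 8);
    std::printf("truncated prompt size: %zu\n", truncated.size());
    return 0;
}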
@@ -1331,6 +1339,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
         }
+        else if (arg == "-cps" || arg == "--ctx-per-slot" || arg == "--ctx_per_slot")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            llama.max_ctx_per_slot = std::stoi(argv[i]);
+        }
         else if (arg == "--rope-freq-base")
         {
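For reference, an invocation using the new flag might look like the following; the model path is only a placeholder, and this assumes the server binary and the existing -c (context size) and -np (number of parallel slots) options from this branch:

./server -m models/7B/model.gguf -c 4096 -np 4 -cps 1024

With these illustrative numbers each of the four slots gets 1024 tokens of context, which also satisfies the 4 * 1024 <= 4096 check added in initialize(); omitting -cps would produce the same even split automatically.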
@@ -1717,7 +1734,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", 0);
+    slot->params.n_keep = json_value(body, "n_keep", slot->params.n_keep);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
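Reading n_keep into slot->params instead of the shared llama.params means one client's request can no longer change the keep-length used by every other slot. A small sketch of that per-request override pattern, assuming nlohmann::json and a simplified stand-in for the server's json_value helper (the real helper lives in server.cpp and may differ in detail):

#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Simplified json_value: read a field from the request body,
// falling back to the supplied default when it is absent or null.
template <typename T>
static T json_value(const json &body, const std::string &key, const T &default_value) {
    return body.contains(key) && !body.at(key).is_null() ? body.at(key).get<T>() : default_value;
}

int main() {
    int slot_n_keep = 32;   // per-slot default (hypothetical value)

    // One request overrides n_keep, another does not.
    json with_override    = json::parse(R"({ "n_keep": 64, "temperature": 0.7 })");
    json without_override = json::parse(R"({ "temperature": 0.7 })");

    // The value stays with this slot, so it does not leak into other clients' slots.
    slot_n_keep = json_value(with_override, "n_keep", slot_n_keep);
    std::printf("n_keep after override: %d\n", slot_n_keep);     // 64

    slot_n_keep = json_value(without_override, "n_keep", slot_n_keep);
    std::printf("n_keep without override: %d\n", slot_n_keep);   // still 64
    return 0;
}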