mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
chat.mjs support cached prompt + some fixes
This commit is contained in:
parent
500ac7120e
commit
4ba5a5013d
@ -7,6 +7,11 @@ const args = process.argv.slice(2);
|
||||
const grammarJsonSchemaFile = args.find(
|
||||
(_, index) => args[index - 1] === "--grammar-json-schema"
|
||||
);
|
||||
|
||||
const no_cached_prompt = args.find(
|
||||
(_, index) => args[index - 1] === "--no-cache-prompt"
|
||||
) ?? "false";
|
||||
|
||||
const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
|
||||
|
||||
// Example usage: function,arguments
|
||||
@ -30,6 +35,9 @@ if (grammarFile) {
|
||||
grammar = readFileSync(grammarFile, 'utf-8')
|
||||
}
|
||||
|
||||
// for cached prompt
|
||||
let slot_id = -1;
|
||||
|
||||
const API_URL = 'http://127.0.0.1:8080'
|
||||
|
||||
const chat = [
|
||||
@ -76,7 +84,9 @@ async function chat_completion(question) {
|
||||
top_p: 0.9,
|
||||
n_keep: n_keep,
|
||||
n_predict: 256,
|
||||
stop: ["\n### Human:"], // stop completion after generating this
|
||||
cache_prompt: no_cached_prompt === "false",
|
||||
slot_id: slot_id,
|
||||
stop: ["### Human:"], // stop completion after generating this
|
||||
grammar,
|
||||
stream: true,
|
||||
})
|
||||
@ -92,6 +102,7 @@ async function chat_completion(question) {
|
||||
const t = Buffer.from(chunk).toString('utf8')
|
||||
if (t.startsWith('data: ')) {
|
||||
const message = JSON.parse(t.substring(6))
|
||||
slot_id = message.slot_id
|
||||
answer += message.content
|
||||
process.stdout.write(message.content)
|
||||
if (message.stop) {
|
||||
|
@ -407,14 +407,13 @@ struct llama_server_context
|
||||
{
|
||||
llama_client_slot slot;
|
||||
slot.id = i;
|
||||
slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
|
||||
slot.reset();
|
||||
slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
|
||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
|
||||
slot.reset();
|
||||
LOG_TEE(" -> Slot %i\n", slot.id);
|
||||
slots.push_back(slot);
|
||||
}
|
||||
LOG_TEE("Context Size: %i\n", params.n_ctx);
|
||||
batch = llama_batch_init(params.n_ctx, 0);
|
||||
batch = llama_batch_init(n_ctx, 0);
|
||||
// empty system prompt
|
||||
system_prompt = "";
|
||||
num_tokens_system = 0;
|
||||
@ -465,38 +464,6 @@ struct llama_server_context
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
void processPrompt() {
|
||||
//params.n_keep = std::min(n_ctx - 4, params.n_keep);
|
||||
|
||||
// if input prompt is too big, truncate like normal
|
||||
// if (num_prompt_tokens >= (size_t)n_ctx)
|
||||
// {
|
||||
// const int n_left = (n_ctx - params.n_keep) / 2;
|
||||
// std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
// const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
// new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
// std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
|
||||
|
||||
// LOG_VERBOSE("input truncated", {
|
||||
// {"n_ctx", n_ctx},
|
||||
// {"n_keep", params.n_keep},
|
||||
// {"n_left", n_left},
|
||||
// {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
// });
|
||||
|
||||
// truncated = true;
|
||||
// prompt_tokens = new_tokens;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// const size_t ps = num_prompt_tokens;
|
||||
// std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
|
||||
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
|
||||
// }
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
}
|
||||
|
||||
llama_client_slot* getSlot(int id) {
|
||||
for (llama_client_slot & slot : slots)
|
||||
{
|
||||
@ -740,10 +707,11 @@ struct llama_server_context
|
||||
// release the slot
|
||||
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
|
||||
{
|
||||
LOG_TEE("slot %i released\n", slot.id);
|
||||
slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
|
||||
if(slot.state == SLEEPING) {
|
||||
printf("%i has cached prompt.");
|
||||
LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.n_past);
|
||||
} else {
|
||||
LOG_TEE("slot %i released\n", slot.id);
|
||||
}
|
||||
slot.command = NONE;
|
||||
continue;
|
||||
@ -773,8 +741,9 @@ struct llama_server_context
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// need process the prompt
|
||||
bool keep_gen = slot.state == SLEEPING; // remember generation
|
||||
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
|
||||
if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
if(slot.infill) {
|
||||
bool suff_rm_leading_spc = true;
|
||||
@ -800,10 +769,7 @@ struct llama_server_context
|
||||
|
||||
slot.num_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
|
||||
|
||||
printf("n_past: %i, context: %i, prompt: %i, cache: %s\n",
|
||||
slot.n_past ,slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
|
||||
slot.n_past = slot.params.cache_prompt ? common_part(slot.context_tokens, prompt_tokens) : 0;
|
||||
|
||||
slot.context_tokens = prompt_tokens;
|
||||
|
||||
@ -813,6 +779,35 @@ struct llama_server_context
|
||||
slot.n_past--;
|
||||
}
|
||||
|
||||
if(!slot.params.cache_prompt) {
|
||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
|
||||
} else {
|
||||
LOG_TEE("slot %i - cached: %i tokens | to eval: %i tokens\n", slot.id, slot.n_past, (slot.num_prompt_tokens - slot.n_past));
|
||||
//if input prompt is too big, truncate like normal
|
||||
if (slot.num_prompt_tokens >= (size_t)n_ctx)
|
||||
{
|
||||
const int n_left = (n_ctx - params.n_keep) / 2;
|
||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
const int erased_blocks = (slot.num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), slot.last_n_tokens.begin());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
});
|
||||
|
||||
slot.truncated = true;
|
||||
prompt_tokens = new_tokens;
|
||||
} else {
|
||||
const size_t ps = slot.num_prompt_tokens;
|
||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
@ -820,10 +815,7 @@ struct llama_server_context
|
||||
{"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
|
||||
{"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
|
||||
});
|
||||
|
||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
|
||||
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
|
||||
printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
|
||||
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
|
||||
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
|
||||
batch.seq_id[batch.n_tokens] = slot.id;
|
||||
@ -838,8 +830,6 @@ struct llama_server_context
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1743,7 +1733,7 @@ int main(int argc, char **argv)
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(5));
|
||||
}
|
||||
}
|
||||
const json data = format_final_response(
|
||||
|
Loading…
Reference in New Issue
Block a user