chat.mjs support cached prompt + some fixes

FSSRepo 2023-10-13 11:06:41 -04:00
parent 500ac7120e
commit 4ba5a5013d
2 changed files with 52 additions and 51 deletions

examples/server/chat.mjs

@@ -7,6 +7,11 @@ const args = process.argv.slice(2);
const grammarJsonSchemaFile = args.find(
(_, index) => args[index - 1] === "--grammar-json-schema"
);
const no_cached_prompt = args.find(
(_, index) => args[index - 1] === "--no-cache-prompt"
) ?? "false";
const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
// Example usage: function,arguments
@@ -30,6 +35,9 @@ if (grammarFile) {
grammar = readFileSync(grammarFile, 'utf-8')
}
// for cached prompt
let slot_id = -1;
const API_URL = 'http://127.0.0.1:8080'
const chat = [
@@ -76,7 +84,9 @@ async function chat_completion(question) {
top_p: 0.9,
n_keep: n_keep,
n_predict: 256,
stop: ["\n### Human:"], // stop completion after generating this
cache_prompt: no_cached_prompt === "false",
slot_id: slot_id,
stop: ["### Human:"], // stop completion after generating this
grammar,
stream: true,
})
@@ -92,6 +102,7 @@ async function chat_completion(question) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
slot_id = message.slot_id
answer += message.content
process.stdout.write(message.content)
if (message.stop) {

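Taken together, the chat.mjs changes add a --no-cache-prompt switch and, when caching is left on, pin every follow-up request to the server slot that already holds the conversation: the request carries cache_prompt and slot_id, and the streamed response reports back which slot served it. A minimal sketch of that round trip against the server's streaming /completion endpoint, using the field names from the hunks above (the prompt argument is only a placeholder):

// cached-prompt round trip against the llama.cpp server (Node 18+)
const API_URL = 'http://127.0.0.1:8080';
let slot_id = -1; // -1 on the first turn lets the server pick any idle slot

async function complete(prompt) {
  const result = await fetch(`${API_URL}/completion`, {
    method: 'POST',
    body: JSON.stringify({
      prompt,
      n_predict: 256,
      cache_prompt: true, // ask the slot to keep the evaluated prompt around
      slot_id,            // reuse the slot that cached the previous turns
      stream: true,
    }),
  });
  let answer = '';
  for await (const chunk of result.body) {
    const t = Buffer.from(chunk).toString('utf8');
    if (t.startsWith('data: ')) {
      const message = JSON.parse(t.substring(6));
      slot_id = message.slot_id; // remember the slot for the next request
      answer += message.content;
      if (message.stop) break;
    }
  }
  return answer;
}

Note the way the real script parses the switch: caching is only disabled when --no-cache-prompt is followed by a value other than "false" (e.g. node chat.mjs --no-cache-prompt true); with no value after it, cache_prompt stays true.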
examples/server/server.cpp

@@ -407,14 +407,13 @@ struct llama_server_context
{
llama_client_slot slot;
slot.id = i;
slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
slot.reset();
slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
slot.reset();
LOG_TEE(" -> Slot %i\n", slot.id);
slots.push_back(slot);
}
LOG_TEE("Context Size: %i\n", params.n_ctx);
batch = llama_batch_init(params.n_ctx, 0);
batch = llama_batch_init(n_ctx, 0);
// empty system prompt
system_prompt = "";
num_tokens_system = 0;
@@ -465,38 +464,6 @@ struct llama_server_context
return prompt_tokens;
}
void processPrompt() {
//params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
// if (num_prompt_tokens >= (size_t)n_ctx)
// {
// const int n_left = (n_ctx - params.n_keep) / 2;
// std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
// const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
// new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
// std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
// LOG_VERBOSE("input truncated", {
// {"n_ctx", n_ctx},
// {"n_keep", params.n_keep},
// {"n_left", n_left},
// {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
// });
// truncated = true;
// prompt_tokens = new_tokens;
// }
// else
// {
// const size_t ps = num_prompt_tokens;
// std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// }
// compare the evaluated prompt with the new prompt
}
llama_client_slot* getSlot(int id) {
for (llama_client_slot & slot : slots)
{
@@ -740,10 +707,11 @@ struct llama_server_context
// release the slot
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
{
LOG_TEE("slot %i released\n", slot.id);
slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
if(slot.state == SLEEPING) {
printf("%i has cached prompt.");
LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.n_past);
} else {
LOG_TEE("slot %i released\n", slot.id);
}
slot.command = NONE;
continue;
@@ -773,8 +741,9 @@ struct llama_server_context
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// need process the prompt
bool keep_gen = slot.state == SLEEPING; // remember generation
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
std::vector<llama_token> prompt_tokens;
if(slot.infill) {
bool suff_rm_leading_spc = true;
@@ -800,10 +769,7 @@ struct llama_server_context
slot.num_prompt_tokens = prompt_tokens.size();
slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
printf("n_past: %i, context: %i, prompt: %i, cache: %s\n",
slot.n_past ,slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
slot.n_past = slot.params.cache_prompt ? common_part(slot.context_tokens, prompt_tokens) : 0;
slot.context_tokens = prompt_tokens;
@@ -813,6 +779,35 @@ struct llama_server_context
slot.n_past--;
}
if(!slot.params.cache_prompt) {
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
} else {
LOG_TEE("slot %i - cached: %i tokens | to eval: %i tokens\n", slot.id, slot.n_past, (slot.num_prompt_tokens - slot.n_past));
//if input prompt is too big, truncate like normal
if (slot.num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (slot.num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), slot.last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
slot.truncated = true;
prompt_tokens = new_tokens;
} else {
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
}
}
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
LOG_VERBOSE("prompt ingested", {
@@ -820,10 +815,7 @@ struct llama_server_context
{"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
});
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
@@ -838,8 +830,6 @@ struct llama_server_context
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.state = PROCESSING;
slot.command = NONE;
}
}
}
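The server-side half of the feature is in the hunks above: a slot released with cache_prompt set is parked as SLEEPING instead of wiped, and on the next LOAD_PROMPT the new prompt is compared against the slot's existing context_tokens so that only the unseen suffix is evaluated, with llama_kv_cache_seq_rm dropping the stale KV entries past slot.n_past. The truncation block reuses the old per-context logic at slot level when a prompt exceeds n_ctx. A sketch of both computations, in JavaScript for consistency with the client example (common_part is presumably a longest-common-prefix helper; the token IDs below are made up):

// slot.n_past after a cached follow-up turn: the number of leading tokens
// the cached context and the new prompt share, i.e. tokens we can skip.
function commonPart(cached, prompt) {
  let i = 0;
  while (i < cached.length && i < prompt.length && cached[i] === prompt[i]) i++;
  return i;
}

// Truncation when the prompt no longer fits the context window: keep the
// first n_keep tokens, drop whole blocks of n_left tokens from the middle,
// and keep the most recent remainder (mirrors the n_left/erased_blocks math).
function truncatePrompt(prompt, n_ctx, n_keep) {
  if (prompt.length < n_ctx) return prompt;
  const n_left = Math.floor((n_ctx - n_keep) / 2);
  const erased_blocks = Math.floor((prompt.length - n_keep - n_left - 1) / n_left);
  return prompt.slice(0, n_keep).concat(prompt.slice(n_keep + erased_blocks * n_left));
}

// Example: a second turn that extends the first prompt by two tokens.
const cached = [1, 29871, 3186, 1724];              // tokens evaluated on turn one
const next   = [1, 29871, 3186, 1724, 338, 29973];  // same prefix + 2 new tokens
console.log(commonPart(cached, next)); // 4 -> only 2 tokens need evaluation

With cache_prompt off, n_past stays 0 and last_n_tokens is simply zero-filled, which matches the pre-existing behaviour.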
@@ -1743,7 +1733,7 @@ int main(int argc, char **argv)
return false;
}
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::this_thread::sleep_for(std::chrono::microseconds(5));
}
}
const json data = format_final_response(