diff --git a/examples/server/chat.mjs b/examples/server/chat.mjs
index 87f4d2926..103116bf2 100644
--- a/examples/server/chat.mjs
+++ b/examples/server/chat.mjs
@@ -7,6 +7,11 @@ const args = process.argv.slice(2);
 const grammarJsonSchemaFile = args.find(
     (_, index) => args[index - 1] === "--grammar-json-schema"
 );
+
+const no_cached_prompt = args.find(
+    (_, index) => args[index - 1] === "--no-cache-prompt"
+) ?? "false";
+
 const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
 
 // Example usage: function,arguments
@@ -30,6 +35,9 @@ if (grammarFile) {
     grammar = readFileSync(grammarFile, 'utf-8')
 }
 
+// for cached prompt
+let slot_id = -1;
+
 const API_URL = 'http://127.0.0.1:8080'
 
 const chat = [
@@ -76,7 +84,9 @@ async function chat_completion(question) {
             top_p: 0.9,
             n_keep: n_keep,
             n_predict: 256,
-            stop: ["\n### Human:"], // stop completion after generating this
+            cache_prompt: no_cached_prompt === "false",
+            slot_id: slot_id,
+            stop: ["### Human:"], // stop completion after generating this
             grammar,
             stream: true,
         })
@@ -92,6 +102,7 @@ async function chat_completion(question) {
         const t = Buffer.from(chunk).toString('utf8')
         if (t.startsWith('data: ')) {
             const message = JSON.parse(t.substring(6))
+            slot_id = message.slot_id
             answer += message.content
             process.stdout.write(message.content)
             if (message.stop) {
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 18ee4b5be..dbeb5fe5a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -407,14 +407,13 @@ struct llama_server_context
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(params.n_predict); // max prediction per slot
-            slot.reset();
+            slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
             std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+            slot.reset();
             LOG_TEE(" -> Slot %i\n", slot.id);
             slots.push_back(slot);
         }
-        LOG_TEE("Context Size: %i\n", params.n_ctx);
-        batch = llama_batch_init(params.n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0);
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
@@ -465,38 +464,6 @@ struct llama_server_context
         return prompt_tokens;
     }
 
-    void processPrompt() {
-        //params.n_keep = std::min(n_ctx - 4, params.n_keep);
-
-        // if input prompt is too big, truncate like normal
-        // if (num_prompt_tokens >= (size_t)n_ctx)
-        // {
-        //     const int n_left = (n_ctx - params.n_keep) / 2;
-        //     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-        //     const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-        //     new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-        //     std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
-
-        //     LOG_VERBOSE("input truncated", {
-        //         {"n_ctx", n_ctx},
-        //         {"n_keep", params.n_keep},
-        //         {"n_left", n_left},
-        //         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-        //     });
-
-        //     truncated = true;
-        //     prompt_tokens = new_tokens;
-        // }
-        // else
-        // {
-        //     const size_t ps = num_prompt_tokens;
-        //     std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-        //     std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
-        // }
-
-        // compare the evaluated prompt with the new prompt
-    }
-
     llama_client_slot* getSlot(int id) {
         for (llama_client_slot & slot : slots)
         {
@@ -740,10 +707,11 @@ struct llama_server_context
             // release the slot
             if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
             {
-                LOG_TEE("slot %i released\n", slot.id);
                 slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
                 if(slot.state == SLEEPING) {
-                    printf("%i has cached prompt.");
+                    LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.n_past);
+                } else {
+                    LOG_TEE("slot %i released\n", slot.id);
                 }
                 slot.command = NONE;
                 continue;
@@ -773,8 +741,9 @@ struct llama_server_context
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // need process the prompt
-                bool keep_gen = slot.state == SLEEPING; // remember generation
-                if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
+                if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
+                    slot.state = PROCESSING;
+                    slot.command = NONE;
                     std::vector<llama_token> prompt_tokens;
                     if(slot.infill) {
                         bool suff_rm_leading_spc = true;
@@ -800,10 +769,7 @@ struct llama_server_context
 
                     slot.num_prompt_tokens = prompt_tokens.size();
 
-                    slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
-
-                    printf("n_past: %i, context: %i, prompt: %i, cache: %s\n",
-                        slot.n_past ,slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
+                    slot.n_past = slot.params.cache_prompt ? common_part(slot.context_tokens, prompt_tokens) : 0;
 
                     slot.context_tokens = prompt_tokens;
 
@@ -813,6 +779,35 @@ struct llama_server_context
                         slot.n_past--;
                     }
 
+                    if(!slot.params.cache_prompt) {
+                        std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                    } else {
+                        LOG_TEE("slot %i - cached: %i tokens | to eval: %i tokens\n", slot.id, slot.n_past, (slot.num_prompt_tokens - slot.n_past));
+                        // if input prompt is too big, truncate like normal
+                        if (slot.num_prompt_tokens >= (size_t)n_ctx)
+                        {
+                            const int n_left = (n_ctx - params.n_keep) / 2;
+                            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+                            const int erased_blocks = (slot.num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+                            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+                            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), slot.last_n_tokens.begin());
+
+                            LOG_VERBOSE("input truncated", {
+                                {"n_ctx", n_ctx},
+                                {"n_keep", params.n_keep},
+                                {"n_left", n_left},
+                                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                            });
+
+                            slot.truncated = true;
+                            prompt_tokens = new_tokens;
+                        } else {
+                            const size_t ps = slot.num_prompt_tokens;
+                            std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
+                            std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                        }
+                    }
+
                     llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
 
                     LOG_VERBOSE("prompt ingested", {
@@ -820,10 +815,7 @@ struct llama_server_context
                         {"cached", tokens_to_str(ctx, slot.context_tokens.cbegin(), slot.context_tokens.cbegin() + slot.n_past)},
                         {"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
                     });
-
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
                     for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
-                        printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
                         batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
                         batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
                         batch.seq_id[batch.n_tokens] = slot.id;
@@ -838,8 +830,6 @@ struct llama_server_context
 
                     slot.n_decoded = 0;
                     slot.i_batch = batch.n_tokens - 1;
-
-                    slot.state = PROCESSING;
-                    slot.command = NONE;
                 }
             }
         }
@@ -1743,7 +1733,7 @@ int main(int argc, char **argv)
                     return false;
                 }
             } else {
-                std::this_thread::sleep_for(std::chrono::milliseconds(5));
+                std::this_thread::sleep_for(std::chrono::microseconds(5));
             }
         }
         const json data = format_final_response(
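
Note on the prompt caching above: the new code path relies on common_part(), which is not part of this diff. It is assumed to return the length of the shared token prefix between the slot's previously evaluated context (slot.context_tokens) and the freshly tokenized prompt, so that slot.n_past can skip the cached part and llama_kv_cache_seq_rm() only drops KV entries past that prefix. A minimal sketch of that assumed behavior (not the actual server.cpp implementation):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // llama.cpp token ids are 32-bit integers

    // length of the common prefix of two token sequences;
    // tokens [0, result) are already in the KV cache and are not re-evaluated
    static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

On the client side, chat.mjs now requests cache_prompt unless the --no-cache-prompt option is given, and remembers the slot_id returned in each streamed message so that follow-up turns reuse the same slot's cache.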