From ee1d670cc6eef301d913b698864e1f4cbbe4d912 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 20 Sep 2023 17:32:21 +0300
Subject: [PATCH] parallel : fix bug (extra BOS) + smaller token_prev array

---
 examples/parallel/parallel.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index b8bd6d936..9c7cfd0dc 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.tokens_prev.resize(n_ctx);
+        client.tokens_prev.resize(params.n_predict);
         std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
     }
 
@@ -191,6 +191,8 @@ int main(int argc, char ** argv) {
             for (int i = 0; i < n_clients; ++i) {
                 llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
             }
+
+            LOG_TEE("%s: clearing the KV cache\n", __func__);
         }
 
         // insert new sequences for decoding
@@ -208,8 +210,9 @@ int main(int argc, char ** argv) {
 
                     std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
 
+                    // do not prepend BOS because we have a system prompt!
                     std::vector<llama_token> tokens_prompt;
-                    tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
+                    tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
 
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                         batch.token [batch.n_tokens] = tokens_prompt[i];
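
Note (not part of the patch): a minimal sketch of the BOS rule this change enforces,
assuming an initialized llama_context * ctx and the ::llama_tokenize helper from
examples/common.h; the tokenize_prompts wrapper is hypothetical and only illustrates
the call pattern used by examples/parallel/parallel.cpp.

#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

// hypothetical helper: tokenize the shared system prompt and one client prompt
static void tokenize_prompts(llama_context * ctx,
                             const std::string & system_prompt,
                             const std::string & client_prompt) {
    // the system prompt sits at the very start of every sequence,
    // so it is the only text that should carry the BOS token
    std::vector<llama_token> tokens_system = ::llama_tokenize(ctx, system_prompt, true);

    // client prompts are appended after the system prompt; passing true here
    // would inject a second BOS mid-sequence, which is the bug this patch removes
    std::vector<llama_token> tokens_prompt = ::llama_tokenize(ctx, client_prompt, false);

    (void) tokens_system;
    (void) tokens_prompt;
}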