parallel : fix bug (extra BOS) + smaller token_prev array

This commit is contained in:
Georgi Gerganov 2023-09-20 17:32:21 +03:00
parent 1be2b8c19b
commit ee1d670cc6

View File

@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) { for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i]; auto & client = clients[i];
client.id = i; client.id = i;
client.tokens_prev.resize(n_ctx); client.tokens_prev.resize(params.n_predict);
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
} }
@ -191,6 +191,8 @@ int main(int argc, char ** argv) {
for (int i = 0; i < n_clients; ++i) { for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
} }
LOG_TEE("%s: clearing the KV cache\n", __func__);
} }
// insert new sequences for decoding // insert new sequences for decoding
@ -208,8 +210,9 @@ int main(int argc, char ** argv) {
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt; std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true); tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) { for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i]; batch.token [batch.n_tokens] = tokens_prompt[i];