passkey : simplify n_past logic

This commit is contained in:
Georgi Gerganov 2024-01-07 17:52:12 +02:00
parent bda3f2c892
commit f2c9800dfb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(512, 0, 1);
int n_past = 0;
// fill the KV cache
for (int i = 0; i < n_ctx; i += n_batch) {
llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}
if (i + n_batch >= n_tokens_all) {
@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
n_past -= n_discard;
llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}
if (i + n_batch >= n_tokens_all) {
@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
}
int n_past = batch.pos[batch.n_tokens - 1];
{
const int n_discard = n_past - n_ctx + n_predict;
@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
fflush(stdout);
n_decode += 1;
n_past += 1;
// prepare the next batch
llama_batch_clear(batch);
// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
}
n_cur += 1;