passkey : simplify n_past logic

This commit is contained in:
Georgi Gerganov 2024-01-07 17:52:12 +02:00
parent bda3f2c892
commit f2c9800dfb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(512, 0, 1); llama_batch batch = llama_batch_init(512, 0, 1);
int n_past = 0;
// fill the KV cache // fill the KV cache
for (int i = 0; i < n_ctx; i += n_batch) { for (int i = 0; i < n_ctx; i += n_batch) {
llama_batch_clear(batch); llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
} }
if (i + n_batch >= n_tokens_all) { if (i + n_batch >= n_tokens_all) {
@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
n_past -= n_discard;
llama_batch_clear(batch); llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
} }
if (i + n_batch >= n_tokens_all) { if (i + n_batch >= n_tokens_all) {
@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
} }
int n_past = batch.pos[batch.n_tokens - 1];
{ {
const int n_discard = n_past - n_ctx + n_predict; const int n_discard = n_past - n_ctx + n_predict;
@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
fflush(stdout); fflush(stdout);
n_decode += 1; n_decode += 1;
n_past += 1;
// prepare the next batch // prepare the next batch
llama_batch_clear(batch); llama_batch_clear(batch);
// push this new token for next evaluation // push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_past, { 0 }, true); llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
} }
n_cur += 1; n_cur += 1;