mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
passkey : simplify n_past logic
This commit is contained in:
parent
bda3f2c892
commit
f2c9800dfb
@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
|
||||||
|
int n_past = 0;
|
||||||
|
|
||||||
// fill the KV cache
|
// fill the KV cache
|
||||||
for (int i = 0; i < n_ctx; i += n_batch) {
|
for (int i = 0; i < n_ctx; i += n_batch) {
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||||
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
|
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i + n_batch >= n_tokens_all) {
|
if (i + n_batch >= n_tokens_all) {
|
||||||
@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
|
|||||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||||
|
|
||||||
|
n_past -= n_discard;
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||||
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
|
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i + n_batch >= n_tokens_all) {
|
if (i + n_batch >= n_tokens_all) {
|
||||||
@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
|
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_past = batch.pos[batch.n_tokens - 1];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
const int n_discard = n_past - n_ctx + n_predict;
|
const int n_discard = n_past - n_ctx + n_predict;
|
||||||
|
|
||||||
@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
|
|||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
n_decode += 1;
|
n_decode += 1;
|
||||||
n_past += 1;
|
|
||||||
|
|
||||||
// prepare the next batch
|
// prepare the next batch
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
// push this new token for next evaluation
|
// push this new token for next evaluation
|
||||||
llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
|
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
n_cur += 1;
|
n_cur += 1;
|
||||||
|
Loading…
Reference in New Issue
Block a user