diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index f3b83d509..682f90e83 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -10,7 +10,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH N_JUNK SEED\n" , argv[0]); + printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]); return 1 ; } @@ -18,6 +18,7 @@ int main(int argc, char ** argv) { int n_junk = 250; // number of times to repeat the junk text int n_keep = 32; // number of tokens in the prompt prefix + int i_pos = -1; // position of the passkey in the junk text if (argc >= 2) { params.model = argv[1]; @@ -28,11 +29,12 @@ int main(int argc, char ** argv) { } if (argc >= 4) { - seed = std::stoi(argv[3]); + i_pos = std::stoi(argv[3]); } - const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; - const std::string prompt_suffix = " What is the pass key? The pass key is"; + if (argc >= 5) { + seed = std::stoi(argv[4]); + } if (seed == -1) { seed = time(NULL); @@ -40,14 +42,20 @@ int main(int argc, char ** argv) { srand(seed); + if (i_pos == -1) { + i_pos = rand() % n_junk; + } + + const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + const std::string prompt_suffix = " What is the pass key? The pass key is"; + // generate junk text params.prompt = prompt_prefix; - const int n_insert = rand() % n_junk; - const int passkey = rand() % 50000 + 1; + const int passkey = rand() % 50000 + 1; for (int i = 0; i < n_junk; i++) { - if (i % n_junk == n_insert) { + if (i % n_junk == i_pos) { params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; } @@ -90,18 +98,20 @@ int main(int argc, char ** argv) { return 1; } - // tokenize the prefix and use it as a sink - const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); - // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); + // tokenize the prefix and use it as a sink + const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); + + const int n_tokens_all = tokens_list.size(); + // we leave a margin of 16 tokens for the generated text - it should contain just the passkey const int n_predict = 16; // total length of the sequences including the prompt - const int n_len = tokens_list.size() + n_predict; + const int n_len = n_tokens_all + n_predict; const int n_ctx = llama_n_ctx(ctx) - n_keep; const int n_kv_req = llama_n_ctx(ctx); @@ -113,7 +123,7 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); - LOG_TEE("prompt tokens: %d\n", (int) tokens_list.size()); + LOG_TEE("prompt tokens: %d\n", n_tokens_all); //LOG_TEE("prompt: %s\n", params.prompt.c_str()); llama_batch batch = llama_batch_init(512, 0, 1); @@ -122,11 +132,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < n_ctx; i += n_batch) { llama_batch_clear(batch); - for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); } - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { batch.logits[batch.n_tokens - 1] = true; } @@ -135,14 +145,14 @@ int main(int argc, char ** argv) { return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { break; } } - for (int i = n_ctx; i < (int) tokens_list.size(); i += n_batch) { + for (int i = n_ctx; i < n_tokens_all; i += n_batch) { const int n_discard = n_batch; LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); @@ -152,11 +162,11 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); - for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); } - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { batch.logits[batch.n_tokens - 1] = true; } @@ -165,7 +175,7 @@ int main(int argc, char ** argv) { return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); } int n_past = batch.pos[batch.n_tokens - 1]; @@ -184,12 +194,12 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); + LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); LOG_TEE("\n"); // main loop - int n_cur = tokens_list.size(); + int n_cur = n_tokens_all; int n_decode = 0; LOG_TEE("%s", prompt_suffix.c_str());