passkey : select pass key pos from CLI

This commit is contained in:
Georgi Gerganov 2024-01-07 14:48:09 +02:00
parent fbb999f592
commit bda3f2c892
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -10,7 +10,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
if (argc == 1 || argv[1][0] == '-') { if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH N_JUNK SEED\n" , argv[0]); printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]);
return 1 ; return 1 ;
} }
@ -18,6 +18,7 @@ int main(int argc, char ** argv) {
int n_junk = 250; // number of times to repeat the junk text int n_junk = 250; // number of times to repeat the junk text
int n_keep = 32; // number of tokens in the prompt prefix int n_keep = 32; // number of tokens in the prompt prefix
int i_pos = -1; // position of the passkey in the junk text
if (argc >= 2) { if (argc >= 2) {
params.model = argv[1]; params.model = argv[1];
@ -28,11 +29,12 @@ int main(int argc, char ** argv) {
} }
if (argc >= 4) { if (argc >= 4) {
seed = std::stoi(argv[3]); i_pos = std::stoi(argv[3]);
} }
const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; if (argc >= 5) {
const std::string prompt_suffix = " What is the pass key? The pass key is"; seed = std::stoi(argv[4]);
}
if (seed == -1) { if (seed == -1) {
seed = time(NULL); seed = time(NULL);
@ -40,14 +42,20 @@ int main(int argc, char ** argv) {
srand(seed); srand(seed);
if (i_pos == -1) {
i_pos = rand() % n_junk;
}
const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
const std::string prompt_suffix = " What is the pass key? The pass key is";
// generate junk text // generate junk text
params.prompt = prompt_prefix; params.prompt = prompt_prefix;
const int n_insert = rand() % n_junk; const int passkey = rand() % 50000 + 1;
const int passkey = rand() % 50000 + 1;
for (int i = 0; i < n_junk; i++) { for (int i = 0; i < n_junk; i++) {
if (i % n_junk == n_insert) { if (i % n_junk == i_pos) {
params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
} }
@ -90,18 +98,20 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// tokenize the prefix and use it as a sink
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true); tokens_list = ::llama_tokenize(ctx, params.prompt, true);
// tokenize the prefix and use it as a sink
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
const int n_tokens_all = tokens_list.size();
// we leave a margin of 16 tokens for the generated text - it should contain just the passkey // we leave a margin of 16 tokens for the generated text - it should contain just the passkey
const int n_predict = 16; const int n_predict = 16;
// total length of the sequences including the prompt // total length of the sequences including the prompt
const int n_len = tokens_list.size() + n_predict; const int n_len = n_tokens_all + n_predict;
const int n_ctx = llama_n_ctx(ctx) - n_keep; const int n_ctx = llama_n_ctx(ctx) - n_keep;
const int n_kv_req = llama_n_ctx(ctx); const int n_kv_req = llama_n_ctx(ctx);
@ -113,7 +123,7 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
LOG_TEE("prompt tokens: %d\n", (int) tokens_list.size()); LOG_TEE("prompt tokens: %d\n", n_tokens_all);
//LOG_TEE("prompt: %s\n", params.prompt.c_str()); //LOG_TEE("prompt: %s\n", params.prompt.c_str());
llama_batch batch = llama_batch_init(512, 0, 1); llama_batch batch = llama_batch_init(512, 0, 1);
@ -122,11 +132,11 @@ int main(int argc, char ** argv) {
for (int i = 0; i < n_ctx; i += n_batch) { for (int i = 0; i < n_ctx; i += n_batch) {
llama_batch_clear(batch); llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
} }
if (i + n_batch >= (int) tokens_list.size()) { if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true; batch.logits[batch.n_tokens - 1] = true;
} }
@ -135,14 +145,14 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
if (i + n_batch >= (int) tokens_list.size()) { if (i + n_batch >= n_tokens_all) {
break; break;
} }
} }
for (int i = n_ctx; i < (int) tokens_list.size(); i += n_batch) { for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
const int n_discard = n_batch; const int n_discard = n_batch;
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
@ -152,11 +162,11 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch); llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
} }
if (i + n_batch >= (int) tokens_list.size()) { if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true; batch.logits[batch.n_tokens - 1] = true;
} }
@ -165,7 +175,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
} }
int n_past = batch.pos[batch.n_tokens - 1]; int n_past = batch.pos[batch.n_tokens - 1];
@ -184,12 +194,12 @@ int main(int argc, char ** argv) {
} }
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
LOG_TEE("\n"); LOG_TEE("\n");
// main loop // main loop
int n_cur = tokens_list.size(); int n_cur = n_tokens_all;
int n_decode = 0; int n_decode = 0;
LOG_TEE("%s", prompt_suffix.c_str()); LOG_TEE("%s", prompt_suffix.c_str());