mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
Some improvements to loading the session with --prompt-cache (#1550)
Improvements to loading the session with `--prompt-cache` in the `main` example. 1. Fix an issue where the `--seed` parameter was ignored when loading a cached prompt. 2. When loading a cached prompt, you previously had to specify the saved prompt (or a prefix of it) again. This pull changes that behavior to default to the prompt that was cached if a prompt wasn't specified by the user.
This commit is contained in:
parent
1fcdcc28b1
commit
66874d4fbc
@ -272,7 +272,7 @@ These options help improve the performance and memory usage of the LLaMA models.
|
|||||||
|
|
||||||
### Prompt Caching
|
### Prompt Caching
|
||||||
|
|
||||||
- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
|
- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
|
||||||
|
|
||||||
### Quantization
|
### Quantization
|
||||||
|
|
||||||
|
@ -134,8 +134,6 @@ int main(int argc, char ** argv) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
|
||||||
params.prompt.insert(0, 1, ' ');
|
|
||||||
|
|
||||||
std::string path_session = params.path_prompt_cache;
|
std::string path_session = params.path_prompt_cache;
|
||||||
std::vector<llama_token> session_tokens;
|
std::vector<llama_token> session_tokens;
|
||||||
@ -155,6 +153,7 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
session_tokens.resize(n_token_count_out);
|
session_tokens.resize(n_token_count_out);
|
||||||
|
llama_set_rng_seed(ctx, params.seed);
|
||||||
|
|
||||||
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
||||||
} else {
|
} else {
|
||||||
@ -163,7 +162,16 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
|
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
|
||||||
|
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||||
|
params.prompt.insert(0, 1, ' ');
|
||||||
|
|
||||||
|
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
} else {
|
||||||
|
embd_inp = session_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
|
||||||
@ -181,7 +189,9 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
n_matching_session_tokens++;
|
n_matching_session_tokens++;
|
||||||
}
|
}
|
||||||
if (n_matching_session_tokens >= embd_inp.size()) {
|
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
|
||||||
|
fprintf(stderr, "%s: using full prompt from session file\n", __func__);
|
||||||
|
} else if (n_matching_session_tokens >= embd_inp.size()) {
|
||||||
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
|
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
|
||||||
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
||||||
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user