llama.cpp (https://github.com/ggerganov/llama.cpp.git)

commit 8f419181d1 (parent 4eb126fff0)

ggml-ci
common/common.h

@@ -156,7 +156,7 @@ struct common_params_sampling {
 };
 
 struct common_params_speculative {
-    int32_t n_ctx = 4096; // draft context size
+    int32_t n_ctx = 0; // draft context size
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
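The draft context size now defaults to 0 instead of a hard-coded 4096. Going by the speculative-simple changes further down, 0 is read as "not set": the value is passed through and the example falls back to the target's batch size when no explicit draft context is given. A small standalone sketch of that fallback (the helper and struct names are illustrative, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper mirroring the parameter plumbing in the example below:
    // a draft n_ctx of 0 is treated as "not set", so the batch size falls back
    // to the target's batch size instead of a hard-coded 4096.
    struct draft_sizes {
        int32_t n_ctx;
        int32_t n_batch;
    };

    static draft_sizes resolve_draft_sizes(int32_t spec_n_ctx, int32_t tgt_n_batch) {
        draft_sizes out;
        out.n_ctx   = spec_n_ctx;                                // passed through as-is (0 = default)
        out.n_batch = spec_n_ctx > 0 ? spec_n_ctx : tgt_n_batch; // fall back to the target batch size
        return out;
    }

    int main() {
        const draft_sizes a = resolve_draft_sizes(   0, 2048); // default: inherit the target batch size
        const draft_sizes b = resolve_draft_sizes(8192, 2048); // explicit draft context
        std::printf("a: n_ctx=%d n_batch=%d\n", (int) a.n_ctx, (int) a.n_batch);
        std::printf("b: n_ctx=%d n_batch=%d\n", (int) b.n_ctx, (int) b.n_batch);
        return 0;
    }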
common/speculative.cpp

@@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
     // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
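The new comment documents the intent of i_start: the draft model can only hold n_ctx tokens, so matching against the cached draft prompt starts at the tail of the target prompt rather than at position 0. A self-contained toy version of the windowing and the matching loop (the reuse-selection thresholds tied to params.n_reuse are omitted here):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Standalone illustration (not the llama.cpp API): the draft model only sees
    // the last n_ctx tokens of the target prompt, so matching against the cached
    // draft prompt starts at i_start rather than at 0.
    int main() {
        const std::vector<int> prompt_tgt = {1, 2, 3, 4, 5, 6, 7, 8}; // target prompt (token ids)
        const std::vector<int> prompt     = {4, 5, 6, 9};             // cached draft prompt from the previous call
        const int n_ctx = 5;                                          // draft context budget

        const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx); // = 3 -> window {4,5,6,7,8}

        // same shape as the matching loop in the hunk: count how many tokens of the
        // cached draft prompt (starting at i) line up with the windowed target prompt
        int best_i = 0, best_n = 0;
        for (int i = 0; i < (int) prompt.size(); ++i) {
            int cur = 0;
            while (i_start + cur < (int) prompt_tgt.size() &&
                   i       + cur < (int) prompt.size() &&
                   prompt_tgt[i_start + cur] == prompt[i + cur]) {
                cur++;
            }
            if (cur > best_n) {
                best_i = i;
                best_n = cur;
            }
        }

        std::printf("i_start = %d, best reuse: i = %d, n = %d\n", i_start, best_i, best_n); // 3, 0, 3
        return 0;
    }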
@@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft(
 
         prompt.clear();
     } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                result.push_back(prompt[i]);
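The added comments describe a shortcut: if the cached draft already continues past the token the target just accepted (id_last), the leftover cached tokens are returned directly and no draft decoding happens. A standalone illustration of that check, using the same variable names as the hunk:

    #include <cstdio>
    #include <vector>

    // Standalone sketch of the shortcut described by the new comments: if the
    // cached draft already continues past id_last, hand the remaining cached
    // tokens back instead of re-running the draft model. The surrounding program
    // is illustrative only.
    int main() {
        const std::vector<int> prompt = {10, 11, 12, 13, 14, 15}; // cached draft prompt
        const int reuse_i = 0;
        const int reuse_n = 3;   // prompt[0..2] matched the target prompt
        const int id_last = 13;  // token the target model just accepted
        const int n_draft = 16;

        std::vector<int> result;
        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                result.push_back(prompt[i]);
                if ((int) result.size() >= n_draft) {
                    break;
                }
            }
        }

        for (int id : result) {
            std::printf("%d ", id); // prints: 14 15
        }
        std::printf("\n");
        return 0;
    }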
@@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft(
                 break;
             }
         }
 
         return result;
     }
 
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
-        prompt.erase(prompt.begin() + reuse_n, prompt.end());
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
     common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt_tgt.size() - i_start;
-
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
     // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
     }
 
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
     common_batch_clear(batch);
     common_batch_add (batch, id_last, n_past, { 0 }, true);
 
     prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
     llama_decode(ctx, batch);
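The KV-cache bookkeeping is now guarded: the cache is only shifted when there is actually a prefix to drop (reuse_i > 0) and only truncated when the cached prompt extends past the reused span, and n_past is taken from the cached draft prompt after the new target tokens are appended. The vector sketch below models the effect of those calls on the cached prompt; it stands in for the llama_kv_cache_seq_* calls, which operate on the draft KV cache rather than on a std::vector:

    #include <cstdio>
    #include <vector>

    // Standalone model of what the guarded KV-cache calls accomplish on the cached
    // draft prompt: drop the first reuse_i entries (and shift what remains to
    // position 0), then truncate everything past the reused span.
    int main() {
        std::vector<int> prompt = {10, 11, 12, 13, 14, 15}; // cached draft prompt
        const int reuse_i = 2;                              // reused span starts here
        const int reuse_n = 3;                              // and is this long

        if (reuse_i > 0) {
            // analogue of: seq_rm(0, 0, reuse_i) followed by seq_add(0, reuse_i, -1, -reuse_i)
            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
        }

        if (reuse_n < (int) prompt.size()) {
            // analogue of: seq_rm(0, reuse_n, -1)
            prompt.erase(prompt.begin() + reuse_n, prompt.end());
        }

        // after appending any new target tokens, n_past is simply prompt.size()
        for (int id : prompt) {
            std::printf("%d ", id); // prints: 12 13 14
        }
        std::printf("\n");
        return 0;
    }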
common/speculative.h

@@ -6,10 +6,10 @@
 struct common_speculative;
 
 struct common_speculative_params {
-    int n_draft = 16;
+    int n_draft = 16; // max drafted tokens
     int n_reuse = 256;
 
-    float p_min = 0.9f;
+    float p_min = 0.9f; // min probabiliy required to accept a token in the draft
 };
 
 struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
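The new comment on p_min states what the field is for: a minimum probability a drafted token must reach to be kept. The exact rule lives in speculative.cpp and is not part of this diff; below is only a minimal standalone sketch of a cut-off of that kind:

    #include <cstdio>
    #include <utility>
    #include <vector>

    // Illustrative only: stop collecting drafted tokens as soon as the draft
    // model's confidence in its own sample drops below p_min. The real logic in
    // common/speculative.cpp also honours n_draft and the reuse settings.
    int main() {
        const float p_min   = 0.9f;
        const int   n_draft = 16;

        // (token id, probability assigned by the draft model), in draft order
        const std::vector<std::pair<int, float>> sampled = {
            {101, 0.99f}, {102, 0.95f}, {103, 0.62f}, {104, 0.97f},
        };

        std::vector<int> draft;
        for (const auto & s : sampled) {
            if (s.second < p_min) {
                break; // not confident enough -> stop drafting here
            }
            draft.push_back(s.first);
            if ((int) draft.size() >= n_draft) {
                break;
            }
        }

        std::printf("drafted %zu tokens\n", draft.size()); // drafted 2 tokens
        return 0;
    }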
@@ -21,9 +21,8 @@ bool common_speculative_are_compatible(
         const struct llama_context * ctx_dft);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-//
 llama_tokens common_speculative_gen_draft(
-    struct common_speculative * spec,
-    struct common_speculative_params params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
examples/speculative-simple/speculative-simple.cpp

@@ -46,8 +46,11 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.model = params.speculative.model;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
 
     if (params.speculative.cpuparams.n_threads > 0) {
         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
@@ -66,8 +69,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
 
-    if ((int) inp.size() > llama_n_ctx(ctx_tgt)) {
-        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+    if (llama_n_ctx(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
         return 1;
     }
 
+    if (llama_n_batch(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+        return 1;
+    }
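The single "prompt too long" check is split into two explicit checks with clearer messages: the tokenized prompt has to fit both the target context and the batch used to evaluate it. A standalone sketch of the same validation (values are made up for illustration):

    #include <cstdio>

    // Standalone version of the two checks: the prompt must fit the context and,
    // since the example evaluates it in one go, it must also fit the batch.
    static int check_prompt(int n_prompt, int n_ctx, int n_batch) {
        if (n_ctx < n_prompt) {
            std::fprintf(stderr, "the prompt exceeds the context size (%d tokens, ctx %d)\n", n_prompt, n_ctx);
            return 1;
        }
        if (n_batch < n_prompt) {
            std::fprintf(stderr, "the prompt exceeds the batch size (%d tokens, batch %d)\n", n_prompt, n_batch);
            return 1;
        }
        return 0;
    }

    int main() {
        return check_prompt(/*n_prompt=*/3000, /*n_ctx=*/4096, /*n_batch=*/2048); // fails the batch check
    }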
@@ -114,7 +123,7 @@ int main(int argc, char ** argv) {
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_reuse = 256;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
     params_spec.p_min = p_min;
 
     struct common_speculative * spec = common_speculative_init(ctx_dft);
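n_reuse is no longer a fixed 256 but is tied to the draft context: everything up to llama_n_ctx(ctx_dft) - n_draft can be reused, leaving exactly n_draft positions free for newly drafted tokens. A quick arithmetic sketch (the numbers are examples, not defaults):

    #include <cstdio>

    // Illustration of the new n_reuse formula: reserve n_draft positions of the
    // draft context for the tokens that will be drafted, and allow everything
    // else to be reused from the previous call.
    int main() {
        const int n_ctx_dft = 4096; // draft model context size (llama_n_ctx(ctx_dft))
        const int n_draft   = 16;   // tokens drafted per call

        const int n_reuse = n_ctx_dft - n_draft;

        std::printf("n_reuse = %d (draft context %d minus draft budget %d)\n", n_reuse, n_ctx_dft, n_draft);
        return 0;
    }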