common : final touches

ggml-ci
Georgi Gerganov 2024-11-24 19:19:12 +02:00
parent 4eb126fff0
commit 8f419181d1
4 changed files with 45 additions and 24 deletions


@@ -156,7 +156,7 @@ struct common_params_sampling {
 };
 
 struct common_params_speculative {
-    int32_t n_ctx = 4096; // draft context size
+    int32_t n_ctx = 0;    // draft context size
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)


@@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft(
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
@@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft(
         prompt.clear();
     } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
         if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                 result.push_back(prompt[i]);
@@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft(
                     break;
                 }
             }
 
             return result;
         }
 
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
-        prompt.erase(prompt.begin() + reuse_n, prompt.end());
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
     common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt_tgt.size() - i_start;
-
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
     // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
     }
 
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
     common_batch_clear(batch);
     common_batch_add (batch, id_last, n_past, { 0 }, true);
 
     prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
     llama_decode(ctx, batch);
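
Note: the reuse path above keeps the matched slice of the draft KV cache, slides it to position 0 with a negative-delta shift, and trims everything past it. A minimal self-contained sketch of the same cache manipulation, assuming only the llama_kv_cache_* calls that appear in the diff (the helper name and wrapper are illustrative, not part of the commit):

#include <vector>

#include "llama.h"

// Keep the cells for prompt[reuse_i, reuse_i + reuse_n) in sequence 0 of the
// draft KV cache, shift them to position 0, and drop everything else.
static void keep_and_shift(llama_context * ctx, std::vector<llama_token> & prompt,
                           int reuse_i, int reuse_n) {
    if (reuse_i > 0) {
        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);            // remove cells at positions [0, reuse_i)
        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); // shift the remaining cells left by reuse_i

        prompt.erase(prompt.begin(), prompt.begin() + reuse_i); // keep the token mirror in sync
    }

    if (reuse_n < (int) prompt.size()) {
        llama_kv_cache_seq_rm(ctx, 0, reuse_n, -1);            // remove cells at positions [reuse_n, end)

        prompt.erase(prompt.begin() + reuse_n, prompt.end());
    }
}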


@@ -6,10 +6,10 @@
 struct common_speculative;
 
 struct common_speculative_params {
-    int n_draft = 16;
+    int n_draft = 16; // max drafted tokens
     int n_reuse = 256;
 
-    float p_min = 0.9f;
+    float p_min = 0.9f; // min probability required to accept a token in the draft
 };
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
@@ -21,9 +21,8 @@ bool common_speculative_are_compatible(
         const struct llama_context * ctx_dft);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-//
 llama_tokens common_speculative_gen_draft(
-    struct common_speculative * spec,
-    struct common_speculative_params params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
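
For reference, a hedged usage sketch of this API under the declarations shown above (the function name and wrapper are illustrative; common_speculative_free is assumed to be the matching cleanup from the same header). In a real decoding loop the spec object is created once and kept across steps so the draft KV cache can actually be reused:

#include "common.h"
#include "speculative.h"

// Hypothetical call site; ctx_dft is an already-loaded draft-model context.
llama_tokens draft_once(struct llama_context * ctx_dft,
                        const llama_tokens & prompt_tgt, llama_token id_last) {
    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;                                         // max drafted tokens
    params_spec.n_reuse = llama_n_ctx(ctx_dft) - params_spec.n_draft; // draft context to keep between calls
    params_spec.p_min   = 0.9f;                                       // min probability to accept a drafted token

    struct common_speculative * spec = common_speculative_init(ctx_dft);

    // prompt_tgt: tokens already processed by the target model; id_last: the last sampled target token
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    common_speculative_free(spec);

    return draft;
}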


@@ -46,8 +46,11 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.model = params.speculative.model;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
     params.n_gpu_layers = params.speculative.n_gpu_layers;
     if (params.speculative.cpuparams.n_threads > 0) {
         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
@@ -66,8 +69,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
 
-    if ((int) inp.size() > llama_n_ctx(ctx_tgt)) {
-        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+    if (llama_n_ctx(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
         return 1;
     }
+
+    if (llama_n_batch(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+        return 1;
+    }
@@ -114,7 +123,7 @@ int main(int argc, char ** argv) {
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_reuse = 256;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
     params_spec.p_min = p_min;
 
     struct common_speculative * spec = common_speculative_init(ctx_dft);
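
With this change the reuse window is tied to the size of the draft context instead of a fixed 256 tokens: everything except the drafting budget can be kept between steps. A worked example with illustrative numbers (not taken from the commit):

// Illustrative values only.
const int n_ctx_dft = 4096;                 // llama_n_ctx(ctx_dft) for a hypothetical draft model
const int n_draft   = 16;                   // drafting budget per step
const int n_reuse   = n_ctx_dft - n_draft;  // 4080 tokens of previous draft context may be reused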