This commit is contained in:
Georgi Gerganov 2024-12-24 09:54:51 +03:00 committed by GitHub
commit 9a8b96b2c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 22 deletions

View File

@ -11,7 +11,9 @@
struct common_speculative { struct common_speculative {
struct llama_context * ctx; struct llama_context * ctx;
struct common_sampler * smpl; struct common_sampler * smpl;
struct common_sampler * smpl_infill;
llama_batch batch; llama_batch batch;
llama_tokens prompt; llama_tokens prompt;
@ -20,14 +22,26 @@ struct common_speculative {
struct common_speculative * common_speculative_init( struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft) { struct llama_context * ctx_dft) {
auto * result = new common_speculative { auto * result = new common_speculative {
/* .ctx = */ ctx_dft, /* .ctx = */ ctx_dft,
/* .smpl = */ nullptr, /* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), /* .smpl_infill = */ nullptr,
/* .prompt = */ {}, /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
}; };
// TODO: optimize or pass from outside? {
#if 0 common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
{ {
common_params_sampling params; common_params_sampling params;
params.no_perf = false; params.no_perf = false;
@ -41,22 +55,8 @@ struct common_speculative * common_speculative_init(
COMMON_SAMPLER_TYPE_INFILL, COMMON_SAMPLER_TYPE_INFILL,
}; };
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
} }
#else
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
return result; return result;
} }
@ -67,6 +67,7 @@ void common_speculative_free(struct common_speculative * spec) {
} }
common_sampler_free(spec->smpl); common_sampler_free(spec->smpl);
common_sampler_free(spec->smpl_infill);
llama_batch_free(spec->batch); llama_batch_free(spec->batch);
@ -137,7 +138,7 @@ llama_tokens common_speculative_gen_draft(
llama_token id_last) { llama_token id_last) {
auto & batch = spec->batch; auto & batch = spec->batch;
auto & ctx = spec->ctx; auto & ctx = spec->ctx;
auto & smpl = spec->smpl; auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
auto & prompt = spec->prompt; auto & prompt = spec->prompt;
int reuse_i = 0; int reuse_i = 0;

View File

@ -10,6 +10,8 @@ struct common_speculative_params {
int n_reuse = 256; int n_reuse = 256;
float p_min = 0.9f; // min probability required to accept a token in the draft float p_min = 0.9f; // min probability required to accept a token in the draft
bool infill = false; // use infill sampling (useful for FIM)
}; };
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);

View File

@ -2945,6 +2945,7 @@ struct server_context {
params_spec.n_draft = n_draft_max; params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min; params_spec.p_min = slot.params.speculative.p_min;
params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);