mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
Merge b83cae088c
into 60cfa728e2
This commit is contained in:
commit
9a8b96b2c0
@ -11,7 +11,9 @@
|
||||
|
||||
struct common_speculative {
|
||||
struct llama_context * ctx;
|
||||
|
||||
struct common_sampler * smpl;
|
||||
struct common_sampler * smpl_infill;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt;
|
||||
@ -20,14 +22,26 @@ struct common_speculative {
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_dft) {
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .smpl_infill = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
#if 0
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
@ -41,22 +55,8 @@ struct common_speculative * common_speculative_init(
|
||||
COMMON_SAMPLER_TYPE_INFILL,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#else
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -67,6 +67,7 @@ void common_speculative_free(struct common_speculative * spec) {
|
||||
}
|
||||
|
||||
common_sampler_free(spec->smpl);
|
||||
common_sampler_free(spec->smpl_infill);
|
||||
|
||||
llama_batch_free(spec->batch);
|
||||
|
||||
@ -137,7 +138,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
llama_token id_last) {
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx = spec->ctx;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
|
||||
auto & prompt = spec->prompt;
|
||||
|
||||
int reuse_i = 0;
|
||||
|
@ -10,6 +10,8 @@ struct common_speculative_params {
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
|
||||
|
||||
bool infill = false; // use infill sampling (useful for FIM)
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||
|
@ -2945,6 +2945,7 @@ struct server_context {
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
|
||||
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user