diff --git a/common/speculative.cpp b/common/speculative.cpp index 3fcbb0020..d5a1bd803 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -11,7 +11,9 @@ struct common_speculative { struct llama_context * ctx; + struct common_sampler * smpl; + struct common_sampler * smpl_infill; llama_batch batch; llama_tokens prompt; @@ -20,14 +22,26 @@ struct common_speculative { struct common_speculative * common_speculative_init( struct llama_context * ctx_dft) { auto * result = new common_speculative { - /* .ctx = */ ctx_dft, - /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), - /* .prompt = */ {}, + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .smpl_infill = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt = */ {}, }; - // TODO: optimize or pass from outside? -#if 0 + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 10; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } + { common_params_sampling params; params.no_perf = false; @@ -41,22 +55,8 @@ struct common_speculative * common_speculative_init( COMMON_SAMPLER_TYPE_INFILL, }; - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params); } -#else - { - common_params_sampling params; - params.no_perf = false; - - params.top_k = 10; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } -#endif return result; } @@ -67,6 +67,7 @@ void common_speculative_free(struct common_speculative * spec) { } common_sampler_free(spec->smpl); + common_sampler_free(spec->smpl_infill); llama_batch_free(spec->batch); @@ -137,7 +138,7 @@ llama_tokens common_speculative_gen_draft( llama_token id_last) { auto & batch = spec->batch; auto & ctx = spec->ctx; - auto & smpl = spec->smpl; + auto & smpl = params.infill ? spec->smpl_infill : spec->smpl; auto & prompt = spec->prompt; int reuse_i = 0; diff --git a/common/speculative.h b/common/speculative.h index 50ec03446..3ff126ab5 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -10,6 +10,8 @@ struct common_speculative_params { int n_reuse = 256; float p_min = 0.9f; // min probabiliy required to accept a token in the draft + + bool infill = false; // use infill sampling (useful for FIM) }; struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fa3682a92..beaace4c7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2942,6 +2942,7 @@ struct server_context { params_spec.n_draft = n_draft_max; params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; + params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL; llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);