2024-09-09 21:36:09 +00:00
|
|
|
#include "arg.h"
|
2023-09-03 12:12:08 +00:00
|
|
|
#include "common.h"
|
2024-09-09 21:36:09 +00:00
|
|
|
#include "sampling.h"
|
2024-09-15 17:46:12 +00:00
|
|
|
#include "log.h"
|
2023-09-03 12:12:08 +00:00
|
|
|
#include "llama.h"
|
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
#include <algorithm>
|
2023-09-03 12:12:08 +00:00
|
|
|
#include <cstdio>
|
2024-09-15 17:46:12 +00:00
|
|
|
#include <cstring>
|
|
|
|
#include <random>
|
|
|
|
#include <set>
|
2023-09-03 12:12:08 +00:00
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
|
2023-10-27 21:40:07 +00:00
|
|
|
// Maximum allowed absolute difference between the target and draft model vocab
// sizes; main() refuses to run speculative decoding when the difference exceeds it.
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
|
|
|
// First token id whose text is compared between the target and draft vocabs;
// ids below this value are not checked for a token-by-token match.
// NOTE(review): presumably low ids are special/control tokens that may
// legitimately differ between otherwise compatible models — confirm.
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
struct seq_draft {
|
|
|
|
bool active = false;
|
|
|
|
bool drafting = false;
|
|
|
|
bool skip = false;
|
|
|
|
|
|
|
|
int i_batch_dft = 0;
|
|
|
|
std::vector<int> i_batch_tgt;
|
|
|
|
|
|
|
|
std::vector<llama_token> tokens;
|
2024-03-04 18:24:00 +00:00
|
|
|
std::vector<std::vector<llama_token_data>> dists;
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
struct common_sampler * smpl = nullptr;
|
2023-10-18 13:21:57 +00:00
|
|
|
};
|
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
int main(int argc, char ** argv) {
|
2024-10-10 20:57:42 +00:00
|
|
|
common_params params;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-24 06:03:17 +00:00
|
|
|
// needed to get candidate probs even for temp <= 0.0
|
|
|
|
params.sparams.n_probs = 128;
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
2023-09-03 12:12:08 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2024-10-21 06:37:12 +00:00
|
|
|
if (params.n_predict < -1) {
|
|
|
|
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_init();
|
2024-09-15 17:46:12 +00:00
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
if (params.model_draft.empty()) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: --model-draft is required\n", __func__);
|
2023-09-03 12:12:08 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// max number of parallel drafting sequences (i.e. tree branches)
|
|
|
|
const int n_seq_dft = params.n_parallel;
|
|
|
|
|
2023-11-03 07:41:17 +00:00
|
|
|
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
|
|
|
const float p_split = params.p_split;
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-09-24 06:03:17 +00:00
|
|
|
std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
|
2024-03-04 18:24:00 +00:00
|
|
|
std::uniform_real_distribution<> u_dist;
|
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
// init llama.cpp
|
2024-02-16 09:31:07 +00:00
|
|
|
llama_backend_init();
|
|
|
|
llama_numa_init(params.numa);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
llama_model * model_tgt = NULL;
|
|
|
|
llama_model * model_dft = NULL;
|
|
|
|
|
|
|
|
llama_context * ctx_tgt = NULL;
|
|
|
|
llama_context * ctx_dft = NULL;
|
|
|
|
|
|
|
|
// load the target model
|
2024-10-10 20:57:42 +00:00
|
|
|
common_init_result llama_init_tgt = common_init_from_params(params);
|
2024-08-05 16:14:10 +00:00
|
|
|
model_tgt = llama_init_tgt.model;
|
|
|
|
ctx_tgt = llama_init_tgt.context;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
// load the draft model
|
|
|
|
params.model = params.model_draft;
|
2023-09-13 06:50:46 +00:00
|
|
|
params.n_gpu_layers = params.n_gpu_layers_draft;
|
Threadpool: take 2 (#8672)
* Introduce ggml_compute_threadpool
- OpenMP functional: check
- Vanilla ggml functional: Check
- ggml w/threadpool functional: Check
- OpenMP no regression: No glaring problems
- Vanilla ggml no regression: No glaring problems
- ggml w/threadpool no regression: No glaring problems
* Minor fixes
* fixed use after release bug
* fixed a harmless race condition
* Fix Android build issue
* fix more race conditions
* fix deadlock for cases where cgraph.n_nodes == 1
and fix --poll case
* threadpool: use cpu_get_num_math to set the default number of threadpool threads
This way we avoid using E-Cores and Hyperthreaded siblings.
* bench: create fresh threadpool for each test
For benchmarking it's better to start a fresh pool for each test with the exact number of threads
needed for that test. Having larger pools is suboptimal (causes more load, etc).
* atomics: always use stdatomics with clang and use relaxed memory order when polling in ggml_barrier
This also removes sched_yield() calls from ggml_barrier() to match OpenMP behavior.
* threadpool: make polling the default to match openmp behavior
All command line args now allow for setting poll to 0 (false).
* threadpool: do not wakeup threads in already paused threadpool
* fix potential race condition in check_for_work
* threadpool: do not create two threadpools if their params are identical
* threadpool: reduce pause/resume/wakeup overhead in common cases
We now start threadpool in paused state only if we have two.
The resume is now implicit (ie new work) which allows for reduced locking and context-switch overhead.
* threadpool: add support for hybrid polling
poll params (--poll, ...) now specify "polling level", i.e. how aggressively we poll before waiting on the condition variable.
poll=0 means no polling, 1 means poll for 128K rounds then wait, 2 for 256K rounds, ...
The default value of 50 (ie 50x128K rounds) seems like a decent default across modern platforms.
We can tune this further as things evolve.
* threadpool: reduce the number of barrier required
New work is now indicated with an atomic counter that is incremented for
each new graph that needs to be computed.
This removes the need for extra barrier for clearing the "new_work" and
removes the special case for trivial graphs.
* threadpool: remove special-casing for disposable threadpools
With the efficient hybrid polling there is no need to make disposable pools any different.
This simplifies the overall logic and reduces branching.
Include n_threads in debug print for disposable threadpool.
Declare pause and stop flags as atomic_bool
This doesn't actually generate any memory barriers and simply informs
the thread sanitizer that these flags can be written & read by different
threads without locking.
* threadpool: do not clear barrier counters between graphs computes (fixes race with small graphs)
This fixes the race condition with very small graphs where the main thread happens to
start a new graph while the workers are just about to exit from barriers.
* threadpool: use relaxed order for chunk sync
Full memory barrier is an overkill for this since each thread works on different chunk
* threadpool: remove abort_callback from threadpool state
* threadpool: better naming for thread/cpumask related functions
* threadpool: consistent use of int type for n_threads params
* threadpool: add support for ggml_threadpool_params_default/init
Also removes the need for explicit mask_specified param.
all-zero cpumask means use default (usually inherited) cpu affinity mask.
* threadpool: move typedef into ggml.h
* threadpool: fix apply_priority() function name
* threadpool: fix swift wrapper errors due to n_threads int type cleanup
* threadpool: enable --cpu-mask and other threadpool related options only if threadpool is enabled
* threadpool: replace checks for compute_thread ret code with proper status check
* threadpool: simplify threadpool init logic and fix main thread affinity application
Most of the init code is now exactly the same between threadpool and openmp.
* threadpool: update threadpool resume/pause function names
* threadpool: enable openmp by default for now
* threadpool: don't forget to free workers state when omp is enabled
* threadpool: avoid updating process priority on the platforms that do not require it
On Windows we need to change overall process priority class in order to set thread priorities,
but on Linux, Mac, etc we do not need to touch the overall process settings.
* threadpool: update calling thread prio and affinity only at start/resume
This avoids extra syscalls for each graph_compute()
* llama-bench: turn threadpool params into vectors, add output headers, etc
* llama-bench: add support for cool off between tests --delay
This helps for long running tests on platforms that are thermally limited (phones, laptops, etc).
--delay (disabled by default) introduces the sleep for N seconds before starting each test.
* threadpool: move process priority setting into the apps (bench and cli)
This avoids changing the overall process priority on Windows for the apps
that use ggml/llama.cpp directly.
* threadpool: move all pause/resume logic into ggml
* threadpool: further API cleanup and prep for future refactoring
All threadpool-related functions and structs use the ggml_threadpool prefix.
* threadpool: minor indent fixes
* threadpool: improve setpriority error message
* Update examples/llama-bench/llama-bench.cpp
Co-authored-by: slaren <slarengh@gmail.com>
* threadpool: fix indent in set_threadpool call
* use int32_t for n_thread type in public llama.cpp API
* threadpool: use _new and _free instead of _create and _release
* fix two more public APIs to use int32_t for n_threads
* build: set _GNU_SOURCE for Android
---------
Co-authored-by: Max Krasnyansky <quic_maxk@quicinc.com>
Co-authored-by: fmz <quic_fzaghlou@quic.com>
Co-authored-by: Max Krasnyansky <max.krasnyansky@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
2024-08-29 23:20:53 +00:00
|
|
|
if (params.draft_cpuparams.n_threads > 0) {
|
|
|
|
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
|
2024-01-16 11:04:32 +00:00
|
|
|
}
|
Threadpool: take 2 (#8672)
* Introduce ggml_compute_threadpool
- OpenMP functional: check
- Vanilla ggml functional: Check
- ggml w/threadpool functional: Check
- OpenMP no regression: No glaring problems
- Vanilla ggml no regression: No glaring problems
- ggml w/threadpool no regression: No glaring problems
* Minor fixes
* fixed use after release bug
* fixed a harmless race condition
* Fix Android build issue
* fix more race conditions
* fix deadlock for cases where cgraph.n_nodes == 1
and fix --poll case
* threadpool: use cpu_get_num_math to set the default number of threadpool threads
This way we avoid using E-Cores and Hyperthreaded siblings.
* bench: create fresh threadpool for each test
For benchmarking it's better to start a fresh pool for each test with the exact number of threads
needed for that test. Having larger pools is suboptimal (causes more load, etc).
* atomics: always use stdatomics with clang and use relaxed memory order when polling in ggml_barrier
This also removes sched_yield() calls from ggml_barrier() to match OpenMP behavior.
* threadpool: make polling the default to match openmp behavior
All command line args now allow for setting poll to 0 (false).
* threadpool: do not wakeup threads in already paused threadpool
* fix potential race condition in check_for_work
* threadpool: do not create two threadpools if their params are identical
* threadpool: reduce pause/resume/wakeup overhead in common cases
We now start threadpool in paused state only if we have two.
The resume is now implicit (ie new work) which allows for reduced locking and context-switch overhead.
* threadpool: add support for hybrid polling
poll params (--poll, ...) now specify "polling level", i.e. how aggressively we poll before waiting on the condition variable.
poll=0 means no polling, 1 means poll for 128K rounds then wait, 2 for 256K rounds, ...
The default value of 50 (ie 50x128K rounds) seems like a decent default across modern platforms.
We can tune this further as things evolve.
* threadpool: reduce the number of barrier required
New work is now indicated with an atomic counter that is incremented for
each new graph that needs to be computed.
This removes the need for extra barrier for clearing the "new_work" and
removes the special case for trivial graphs.
* threadpool: remove special-casing for disposable threadpools
With the efficient hybrid polling there is no need to make disposable pools any different.
This simplifies the overall logic and reduces branching.
Include n_threads in debug print for disposable threadpool.
Declare pause and stop flags as atomic_bool
This doesn't actually generate any memory barriers and simply informs
the thread sanitizer that these flags can be written & read by different
threads without locking.
* threadpool: do not clear barrier counters between graphs computes (fixes race with small graphs)
This fixes the race condition with very small graphs where the main thread happens to
start a new graph while the workers are just about to exit from barriers.
* threadpool: use relaxed order for chunk sync
Full memory barrier is an overkill for this since each thread works on different chunk
* threadpool: remove abort_callback from threadpool state
* threadpool: better naming for thread/cpumask related functions
* threadpool: consistent use of int type for n_threads params
* threadpool: add support for ggml_threadpool_params_default/init
Also removes the need for explicit mask_specified param.
all-zero cpumask means use default (usually inherited) cpu affinity mask.
* threadpool: move typedef into ggml.h
* threadpool: fix apply_priority() function name
* threadpool: fix swift wrapper errors due to n_threads int type cleanup
* threadpool: enable --cpu-mask and other threadpool related options only if threadpool is enabled
* threadpool: replace checks for compute_thread ret code with proper status check
* threadpool: simplify threadpool init logic and fix main thread affinity application
Most of the init code is now exactly the same between threadpool and openmp.
* threadpool: update threadpool resume/pause function names
* threadpool: enable openmp by default for now
* threadpool: don't forget to free workers state when omp is enabled
* threadpool: avoid updating process priority on the platforms that do not require it
On Windows we need to change overall process priority class in order to set thread priorities,
but on Linux, Mac, etc we do not need to touch the overall process settings.
* threadpool: update calling thread prio and affinity only at start/resume
This avoids extra syscalls for each graph_compute()
* llama-bench: turn threadpool params into vectors, add output headers, etc
* llama-bench: add support for cool off between tests --delay
This helps for long running tests on platforms that are thermally limited (phones, laptops, etc).
--delay (disabled by default) introduces the sleep for N seconds before starting each test.
* threadpool: move process priority setting into the apps (bench and cli)
This avoids changing the overall process priority on Windows for the apps
that use ggml/llama.cpp directly.
* threadpool: move all pause/resume logic into ggml
* threadpool: further API cleanup and prep for future refactoring
All threadpool-related functions and structs use the ggml_threadpool prefix.
* threadpool: minor indent fixes
* threadpool: improve setpriority error message
* Update examples/llama-bench/llama-bench.cpp
Co-authored-by: slaren <slarengh@gmail.com>
* threadpool: fix indent in set_threadpool call
* use int32_t for n_thread type in public llama.cpp API
* threadpool: use _new and _free instead of _create and _release
* fix two more public APIs to use int32_t for n_threads
* build: set _GNU_SOURCE for Android
---------
Co-authored-by: Max Krasnyansky <quic_maxk@quicinc.com>
Co-authored-by: fmz <quic_fzaghlou@quic.com>
Co-authored-by: Max Krasnyansky <max.krasnyansky@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
2024-08-29 23:20:53 +00:00
|
|
|
|
|
|
|
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
|
2024-10-10 20:57:42 +00:00
|
|
|
common_init_result llama_init_dft = common_init_from_params(params);
|
2024-08-05 16:14:10 +00:00
|
|
|
model_dft = llama_init_dft.model;
|
|
|
|
ctx_dft = llama_init_dft.context;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-04-09 17:44:08 +00:00
|
|
|
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
|
2024-04-09 17:44:08 +00:00
|
|
|
|
|
|
|
const bool vocab_type_dft = llama_vocab_type(model_dft);
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
|
2024-04-09 17:44:08 +00:00
|
|
|
|
|
|
|
if (vocab_type_tgt != vocab_type_dft) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
|
|
|
LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
2024-04-09 17:44:08 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (
|
|
|
|
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
|
|
|
|
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
|
|
|
|
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
|
|
|
|
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
|
|
|
|
) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
2024-04-09 17:44:08 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2023-10-27 21:40:07 +00:00
|
|
|
{
|
|
|
|
const int n_vocab_tgt = llama_n_vocab(model_tgt);
|
|
|
|
const int n_vocab_dft = llama_n_vocab(model_dft);
|
|
|
|
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
|
|
? n_vocab_tgt - n_vocab_dft
|
|
|
|
: n_vocab_dft - n_vocab_tgt;
|
|
|
|
|
|
|
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
|
|
|
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
2023-10-27 21:40:07 +00:00
|
|
|
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
|
|
|
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
|
|
|
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
|
|
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
|
|
|
|
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
|
2024-10-10 20:57:42 +00:00
|
|
|
common_token_to_piece(ctx_tgt, i).c_str(),
|
|
|
|
common_token_to_piece(ctx_dft, i).c_str());
|
2023-10-27 21:40:07 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-11-20 09:50:04 +00:00
|
|
|
|
|
|
|
// Tokenize the prompt
|
2023-09-03 12:12:08 +00:00
|
|
|
std::vector<llama_token> inp;
|
2024-10-10 20:57:42 +00:00
|
|
|
inp = common_tokenize(ctx_tgt, params.prompt, true, true);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
const int max_context_size = llama_n_ctx(ctx_tgt);
|
|
|
|
const int max_tokens_list_size = max_context_size - 4;
|
|
|
|
|
|
|
|
if ((int) inp.size() > max_tokens_list_size) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
2023-09-03 12:12:08 +00:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("\n\n");
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
for (auto id : inp) {
|
2024-10-10 20:57:42 +00:00
|
|
|
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
const int n_input = inp.size();
|
|
|
|
|
|
|
|
const auto t_enc_start = ggml_time_us();
|
|
|
|
|
|
|
|
// eval the prompt with both models
|
2024-10-18 21:18:01 +00:00
|
|
|
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
|
|
|
|
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
|
|
|
|
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
const auto t_enc_end = ggml_time_us();
|
|
|
|
|
|
|
|
// the 2 models should have the same vocab
|
2023-09-28 19:42:38 +00:00
|
|
|
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
// how many tokens to draft each time
|
2023-09-14 16:14:44 +00:00
|
|
|
int n_draft = params.n_draft;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
int n_predict = 0;
|
|
|
|
int n_drafted = 0;
|
|
|
|
int n_accept = 0;
|
|
|
|
|
|
|
|
int n_past_tgt = inp.size();
|
|
|
|
int n_past_dft = inp.size();
|
|
|
|
|
|
|
|
// used to determine end of generation
|
|
|
|
bool has_eos = false;
|
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// target model sampling context (reuse the llama_context's sampling instance)
|
2024-10-10 20:57:42 +00:00
|
|
|
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
2024-09-07 12:16:19 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// draft sequence data
|
|
|
|
std::vector<seq_draft> drafts(n_seq_dft);
|
2023-09-05 05:46:17 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
2024-10-10 20:57:42 +00:00
|
|
|
// allocate llama_sampler for each draft sequence
|
|
|
|
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
2023-09-05 05:46:17 +00:00
|
|
|
}
|
|
|
|
|
2024-10-21 06:37:12 +00:00
|
|
|
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
|
|
|
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
2023-10-11 19:35:46 +00:00
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
const auto t_dec_start = ggml_time_us();
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// sample from the last token of the prompt
|
|
|
|
drafts[0].i_batch_tgt.resize(1);
|
|
|
|
drafts[0].i_batch_tgt[0] = 0;
|
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
while (true) {
|
2024-03-04 18:24:00 +00:00
|
|
|
std::set<int> active_seqs = {};
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// print current draft sequences
|
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
if (!drafts[s].active) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
active_seqs.insert(s);
|
2023-10-18 13:21:57 +00:00
|
|
|
const auto & tokens = drafts[s].tokens;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str());
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
int i_dft = 0;
|
|
|
|
int s_keep = 0;
|
2023-09-14 16:14:44 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
llama_token token_id;
|
|
|
|
std::string token_str;
|
|
|
|
|
|
|
|
// loop until we fail to accept a drafted token or we run out of drafted tokens
|
2023-09-03 12:12:08 +00:00
|
|
|
while (true) {
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
// check if the target token matches any of the drafts
|
|
|
|
// for stochastic sampling, attempt to match the token with the drafted tokens
|
|
|
|
{
|
|
|
|
bool accept = false;
|
|
|
|
if (params.sparams.temp > 0) {
|
|
|
|
// stochastic verification
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
|
2024-03-04 18:24:00 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
auto & dist_tgt = *common_sampler_get_candidates(smpl);
|
2024-03-04 18:24:00 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
float p_tgt = 0.0f;
|
|
|
|
float p_dft = 0.0f;
|
2024-03-04 18:24:00 +00:00
|
|
|
|
|
|
|
while (active_seqs.size() > 0) {
|
|
|
|
// randomly select a sequence to verify from active sequences
|
2024-03-05 03:23:06 +00:00
|
|
|
std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
|
2024-03-04 18:24:00 +00:00
|
|
|
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
|
|
|
|
if (i_dft >= (int) drafts[s].tokens.size()) {
|
|
|
|
drafts[s].active = false;
|
|
|
|
active_seqs.erase(s);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (accept) {
|
|
|
|
// if we already accepted a token, we can skip the rest
|
|
|
|
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
|
|
|
drafts[s].active = false;
|
|
|
|
active_seqs.erase(s);
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
}
|
2024-09-07 12:16:19 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
2024-03-04 18:24:00 +00:00
|
|
|
float r = u_dist(rng);
|
2024-09-07 12:16:19 +00:00
|
|
|
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
|
|
|
|
|
|
|
|
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
|
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
// acquire the token probabilities assigned by the draft and target models
|
|
|
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
|
|
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
|
|
|
p_tgt = dist_tgt.data[i].p;
|
2024-11-14 09:23:45 +00:00
|
|
|
break;
|
2024-03-04 18:24:00 +00:00
|
|
|
}
|
2024-11-14 09:23:45 +00:00
|
|
|
}
|
|
|
|
for (size_t i = 0; i < dist_dft.size; i++) {
|
2024-03-04 18:24:00 +00:00
|
|
|
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
|
|
|
p_dft = dist_dft.data[i].p;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
2024-03-04 18:24:00 +00:00
|
|
|
if (r <= p_tgt / p_dft) {
|
|
|
|
s_keep = s;
|
|
|
|
accept = true;
|
|
|
|
token_id = drafts[s].tokens[i_dft];
|
2024-10-10 20:57:42 +00:00
|
|
|
token_str = common_token_to_piece(ctx_tgt, token_id);
|
|
|
|
common_sampler_accept(smpl, token_id, true);
|
2024-03-04 18:24:00 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
break;
|
|
|
|
} else {
|
2024-10-10 20:57:42 +00:00
|
|
|
LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
drafts[s].active = false;
|
|
|
|
|
|
|
|
// calculate residual probability
|
|
|
|
GGML_ASSERT(dist_tgt.sorted);
|
|
|
|
GGML_ASSERT(dist_dft.sorted);
|
|
|
|
|
|
|
|
// sort dist by id
|
|
|
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
|
return a.id < b.id;
|
|
|
|
});
|
|
|
|
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
|
return a.id < b.id;
|
|
|
|
});
|
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
float sum_probs = 0.0f;
|
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
2024-09-07 12:16:19 +00:00
|
|
|
if (i < dist_dft.size) {
|
|
|
|
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
|
|
|
} else {
|
|
|
|
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
|
|
|
|
}
|
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
sum_probs += dist_tgt.data[i].p;
|
|
|
|
}
|
2024-09-07 12:16:19 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
|
|
dist_tgt.data[i].p /= sum_probs;
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
// sort dist_tgt by p desc
|
|
|
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
|
|
return a.p > b.p;
|
|
|
|
});
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
active_seqs.erase(s);
|
|
|
|
for(int i = 0; i < n_seq_dft; i++) {
|
|
|
|
if (i == s) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
|
|
|
// synchronize active status for sequences with the same drafted token
|
|
|
|
drafts[i].active = drafts[i].active && accept;
|
|
|
|
if (!drafts[i].active) {
|
|
|
|
active_seqs.erase(s);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
if (!accept) {
|
|
|
|
// all drafted tokens were rejected
|
|
|
|
// sample from the target model
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n");
|
2024-09-07 12:16:19 +00:00
|
|
|
std::vector<float> probs(dist_tgt.size);
|
|
|
|
for (size_t i = 0; i < dist_tgt.size; ++i) {
|
|
|
|
probs[i] = dist_tgt.data[i].p;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
|
|
|
|
|
|
const int idx = dist(rng);
|
|
|
|
|
|
|
|
token_id = dist_tgt.data[idx].id;
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_accept(smpl, token_id, true);
|
|
|
|
token_str = common_token_to_piece(ctx_tgt, token_id);
|
2024-03-04 18:24:00 +00:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// greedy verification
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
// sample from the target model
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
2024-10-10 20:57:42 +00:00
|
|
|
token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_accept(smpl, token_id, true);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
token_str = common_token_to_piece(ctx_tgt, token_id);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
if (!drafts[s].active) {
|
|
|
|
continue;
|
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
|
|
|
|
s_keep = s;
|
|
|
|
accept = true;
|
|
|
|
} else {
|
|
|
|
drafts[s].active = false;
|
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-04-21 11:50:41 +00:00
|
|
|
if (llama_token_is_eog(model_tgt, token_id)) {
|
2024-03-04 18:24:00 +00:00
|
|
|
has_eos = true;
|
|
|
|
}
|
|
|
|
++n_predict;
|
|
|
|
|
|
|
|
if (accept) {
|
2023-10-18 13:21:57 +00:00
|
|
|
++n_accept;
|
|
|
|
++n_past_tgt;
|
|
|
|
++n_past_dft;
|
|
|
|
++i_dft;
|
2023-12-06 08:08:17 +00:00
|
|
|
if (params.use_color) {
|
|
|
|
// Color token according to its origin sequence
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
} else {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("%s", token_str.c_str());
|
2023-12-06 08:08:17 +00:00
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
continue;
|
2024-03-04 18:24:00 +00:00
|
|
|
} else {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("%s", token_str.c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
break;
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
2023-09-05 05:46:17 +00:00
|
|
|
}
|
2024-03-04 18:24:00 +00:00
|
|
|
}
|
2023-09-05 05:46:17 +00:00
|
|
|
|
2024-03-04 18:24:00 +00:00
|
|
|
{
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// TODO: simplify
|
2023-09-14 16:14:44 +00:00
|
|
|
{
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
llama_kv_cache_seq_keep(ctx_dft, s_keep);
|
|
|
|
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
|
|
|
llama_kv_cache_seq_keep(ctx_dft, 0);
|
|
|
|
|
|
|
|
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
|
|
|
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
|
|
|
|
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
|
|
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
2023-09-14 16:14:44 +00:00
|
|
|
}
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
drafts[s].active = false;
|
|
|
|
drafts[s].tokens.clear();
|
|
|
|
drafts[s].i_batch_tgt.clear();
|
2024-03-04 18:24:00 +00:00
|
|
|
drafts[s].dists.clear();
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
|
|
|
// note: will be erased after the speculation phase
|
2024-03-04 18:24:00 +00:00
|
|
|
drafts[0].tokens.push_back(token_id);
|
|
|
|
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
2023-10-18 13:21:57 +00:00
|
|
|
drafts[0].i_batch_tgt.push_back(0);
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_batch_clear(batch_dft);
|
|
|
|
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
2024-09-15 17:46:12 +00:00
|
|
|
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
2024-03-04 18:24:00 +00:00
|
|
|
llama_decode(ctx_dft, batch_dft);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
++n_past_dft;
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
2024-10-21 06:37:12 +00:00
|
|
|
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
2023-09-03 12:12:08 +00:00
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
if (drafts[0].smpl) {
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_free(drafts[0].smpl);
|
2024-09-07 12:16:19 +00:00
|
|
|
}
|
2024-10-10 20:57:42 +00:00
|
|
|
drafts[0].smpl = common_sampler_clone(smpl);
|
2023-09-05 05:46:17 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
int n_seq_cur = 1;
|
2023-09-03 12:12:08 +00:00
|
|
|
int n_past_cur = n_past_dft;
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
drafts[s].active = false;
|
|
|
|
drafts[s].drafting = false;
|
|
|
|
}
|
|
|
|
drafts[0].active = true;
|
|
|
|
drafts[0].drafting = true;
|
|
|
|
drafts[0].i_batch_dft = 0;
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_batch_clear(batch_tgt);
|
|
|
|
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
// sample n_draft tokens from the draft model using tree-based sampling
|
2023-09-03 12:12:08 +00:00
|
|
|
for (int i = 0; i < n_draft; ++i) {
|
2023-10-18 13:21:57 +00:00
|
|
|
batch_dft.n_tokens = 0;
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
drafts[s].skip = false;
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
if (!drafts[s].drafting || drafts[s].skip) {
|
|
|
|
continue;
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
2024-10-10 20:57:42 +00:00
|
|
|
k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<int> sa(1, s);
|
|
|
|
|
|
|
|
// attempt to split the branch if the probability is high enough
|
|
|
|
for (int f = 1; f < 8; ++f) {
|
2024-09-07 12:16:19 +00:00
|
|
|
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
|
|
|
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
|
|
|
|
|
|
|
// all previous tokens from this branch are now also part of the new branch
|
|
|
|
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
|
|
|
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
|
|
|
|
if (batch_tgt.seq_id[t][p] == s) {
|
|
|
|
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
|
|
|
|
batch_tgt.n_seq_id[t]++;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// copy the draft state
|
|
|
|
drafts[n_seq_cur].active = true;
|
|
|
|
drafts[n_seq_cur].drafting = true;
|
|
|
|
drafts[n_seq_cur].skip = true;
|
|
|
|
|
|
|
|
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
2024-03-04 18:24:00 +00:00
|
|
|
drafts[n_seq_cur].dists = drafts[s].dists;
|
2023-10-18 13:21:57 +00:00
|
|
|
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
|
|
|
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
if (drafts[n_seq_cur].smpl) {
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_free(drafts[n_seq_cur].smpl);
|
2024-09-07 12:16:19 +00:00
|
|
|
}
|
2024-10-10 20:57:42 +00:00
|
|
|
drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
sa.push_back(n_seq_cur);
|
|
|
|
|
|
|
|
n_seq_cur++;
|
|
|
|
} else {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// add drafted token for each sequence
|
|
|
|
for (int is = 0; is < (int) sa.size(); ++is) {
|
2024-09-07 12:16:19 +00:00
|
|
|
const llama_token id = cur_p->data[is].id;
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
const int s = sa[is];
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_accept(drafts[s].smpl, id, true);
|
2023-09-05 05:46:17 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
drafts[s].tokens.push_back(id);
|
2024-03-04 18:24:00 +00:00
|
|
|
// save cur_p.data into drafts[s].dists
|
2024-09-07 12:16:19 +00:00
|
|
|
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// add unique drafted tokens to the target batch
|
|
|
|
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
|
|
|
// add the token to the batch for batched decoding with the draft model
|
|
|
|
drafts[s].i_batch_dft = batch_dft.n_tokens;
|
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
|
2023-10-18 15:49:40 +00:00
|
|
|
|
|
|
|
if (batch_tgt.n_tokens > n_draft) {
|
|
|
|
drafts[s].drafting = false;
|
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// no sequence is drafting anymore
|
|
|
|
if (batch_dft.n_tokens == 0) {
|
2023-09-03 12:12:08 +00:00
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// evaluate the drafted tokens on the draft model
|
|
|
|
llama_decode(ctx_dft, batch_dft);
|
|
|
|
++n_past_cur;
|
2023-09-03 12:12:08 +00:00
|
|
|
++n_drafted;
|
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
if (batch_tgt.n_tokens > n_draft) {
|
2023-09-05 05:46:17 +00:00
|
|
|
break;
|
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
2023-09-05 05:46:17 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
// evaluate the target model on the drafted tokens
|
|
|
|
{
|
|
|
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
|
|
|
for (int s = 1; s < n_seq_dft; ++s) {
|
|
|
|
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
2023-10-18 13:21:57 +00:00
|
|
|
llama_decode(ctx_tgt, batch_tgt);
|
|
|
|
++n_past_tgt;
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
2023-12-12 09:53:36 +00:00
|
|
|
// the first token is always proposed by the target model before the speculation loop so we erase it here
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
|
|
|
if (!drafts[s].active) {
|
|
|
|
continue;
|
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2023-10-18 13:21:57 +00:00
|
|
|
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
2024-03-04 18:24:00 +00:00
|
|
|
drafts[s].dists.erase(drafts[s].dists.begin());
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
2023-09-03 12:12:08 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
auto t_dec_end = ggml_time_us();
|
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("\n\n");
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
|
|
|
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_INF("\n");
|
|
|
|
LOG_INF("n_draft = %d\n", n_draft);
|
|
|
|
LOG_INF("n_predict = %d\n", n_predict);
|
|
|
|
LOG_INF("n_drafted = %d\n", n_drafted);
|
|
|
|
LOG_INF("n_accept = %d\n", n_accept);
|
|
|
|
LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_INF("\n");
|
|
|
|
LOG_INF("draft:\n\n");
|
2024-09-07 12:16:19 +00:00
|
|
|
// TODO: print sampling/grammar timings for all drafts
|
2024-09-13 06:53:38 +00:00
|
|
|
llama_perf_context_print(ctx_dft);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG_INF("\n");
|
|
|
|
LOG_INF("target:\n\n");
|
2024-10-10 20:57:42 +00:00
|
|
|
common_perf_print(ctx_tgt, smpl);
|
2023-09-03 12:12:08 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_free(smpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
for (int s = 0; s < n_seq_dft; ++s) {
|
2024-10-10 20:57:42 +00:00
|
|
|
common_sampler_free(drafts[s].smpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
llama_batch_free(batch_dft);
|
|
|
|
|
2023-09-03 12:12:08 +00:00
|
|
|
llama_free(ctx_tgt);
|
|
|
|
llama_free_model(model_tgt);
|
|
|
|
|
|
|
|
llama_free(ctx_dft);
|
|
|
|
llama_free_model(model_dft);
|
|
|
|
|
|
|
|
llama_backend_free();
|
|
|
|
|
2024-09-15 17:46:12 +00:00
|
|
|
LOG("\n\n");
|
2023-09-03 12:12:08 +00:00
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|