mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
f07cd35da4
commit
e6dd81f0bc
@ -36,6 +36,10 @@ int main(int argc, char ** argv) {
|
|||||||
// max number of parallel drafting sequences (i.e. tree branches)
|
// max number of parallel drafting sequences (i.e. tree branches)
|
||||||
const int n_seq_dft = params.n_parallel;
|
const int n_seq_dft = params.n_parallel;
|
||||||
|
|
||||||
|
// TODO: make this configurable
|
||||||
|
const float p_accept = 0.4f;
|
||||||
|
const float p_split = 0.3f;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("speculative", "log"));
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
@ -272,8 +276,7 @@ int main(int argc, char ** argv) {
|
|||||||
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: make this configurable
|
if (cur_p[0].p < p_accept) {
|
||||||
if (cur_p[0].p < 0.4) {
|
|
||||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
|
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
|
||||||
drafts[s].drafting = false;
|
drafts[s].drafting = false;
|
||||||
continue;
|
continue;
|
||||||
@ -283,8 +286,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// attempt to split the branch if the probability is high enough
|
// attempt to split the branch if the probability is high enough
|
||||||
for (int f = 1; f < 8; ++f) {
|
for (int f = 1; f < 8; ++f) {
|
||||||
// TODO: make this configurable
|
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
|
||||||
if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
|
|
||||||
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||||
@ -364,7 +366,9 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// account for the last drafted token that we didn't evaluate
|
// account for the last drafted token that we didn't evaluate
|
||||||
++n_drafted;
|
if (batch_tgt.n_tokens > n_draft) {
|
||||||
|
++n_drafted;
|
||||||
|
}
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user