speculative : manage context in common_speculative

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-21 21:27:14 +02:00
parent fe043ff1ff
commit 0f878a657c
9 changed files with 188 additions and 144 deletions

View File

@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf << "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
@@ -1490,6 +1490,66 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequence, set the LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
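For a quick sanity check of the two helpers, here is a minimal standalone sketch (not part of the commit; it assumes it is built against the updated common.h/common.cpp, which now provide the llama_tokens alias and both declarations). Both helpers are used further below to decide how much of a cached prompt can be reused.

#include "common.h"
#include <cassert>

int main() {
    const llama_tokens a = { 1, 2, 3, 4 };
    const llama_tokens b = { 1, 2, 9, 4 };
    const llama_tokens c = { 9, 2, 3, 4 };

    assert(common_lcp(a, b) == 2); // shared prefix { 1, 2 }
    assert(common_lcs(a, c) == 3); // common chunk  { 2, 3, 4 }

    return 0;
}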
//
// Vocab utils
//

View File

@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
};
using llama_tokens = std::vector<llama_token>;
// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
@@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
//
// Batch utils
//
void common_batch_clear(struct llama_batch & batch);
@@ -472,6 +476,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longest common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//

View File

@@ -342,6 +342,28 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
return result;
}
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
std::vector<int> idxs;
idxs.reserve(batch.n_tokens);
std::vector<llama_token> draft;
draft.reserve(batch.n_tokens);
for (int i = 0; i < batch.n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
if (idxs.size() > 0) {
GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
draft.push_back(batch.token[i]);
}
idxs.push_back(i);
}
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
}
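A rough sketch of how this batch-based overload is meant to be fed (an assumed flow, mirroring the speculative example further down: the batch carries the last accepted token followed by the drafted tokens, each with logits enabled; ctx_tgt, smpl, id_last, n_past and the llama_tokens draft are assumed to exist already, so this is a fragment rather than a complete program):

llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

common_batch_clear(batch_tgt);
common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);
for (int i = 0; i < (int) draft.size(); ++i) {
    // drafted tokens occupy consecutive positions after id_last
    common_batch_add(batch_tgt, draft[i], n_past + 1 + i, { 0 }, true);
}

llama_decode(ctx_tgt, batch_tgt);

// walks the logits in order and stops as soon as a sampled token disagrees with the
// corresponding draft token; at least one token is always returned
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);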
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

View File

@@ -73,6 +73,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

View File

@@ -11,9 +11,7 @@ struct common_speculative {
struct common_sampler * smpl;
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
llama_tokens prompt_last;
};
struct common_speculative * common_speculative_init(struct common_speculative_params params) {
@@ -21,12 +19,10 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
/* .params = */ params,
/* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
/* .smpl = */ nullptr,
/* .i_batch_tgt = */ {},
/* .tokens = */ {},
};
// TODO: optimize or pass from outside?
#if 0
#if 1
{
common_sampler_params sparams;
sparams.no_perf = false;
@@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
delete spec;
}
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
llama_kv_cache_clear(spec->params.ctx_dft);
// TODO: error handling
llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
}
void common_speculative_add_draft(
struct common_speculative * spec,
struct llama_batch & batch_tgt,
const llama_tokens & prompt,
llama_token id_last,
int n_past) {
spec->tokens.clear();
llama_token n_past_tgt) {
spec->i_batch_tgt.clear();
spec->i_batch_tgt.push_back(0);
int reuse_i = 0;
int reuse_n = 0;
common_sampler_reset(spec->smpl);
const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt.size() &&
i + cur < (int) spec->prompt_last.size() &&
prompt[i_start + cur] == spec->prompt_last[i + cur]) {
cur++;
}
if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
if (reuse_n == 0) {
llama_kv_cache_clear(spec->params.ctx_dft);
spec->prompt_last.clear();
} else {
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
}
common_batch_clear(spec->batch_dft);
for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
spec->prompt_last.push_back(prompt[i]);
}
const llama_pos n_past = prompt.size() - i_start;
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
if (spec->batch_dft.n_tokens > 0) {
LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
llama_decode(spec->params.ctx_dft, spec->batch_dft);
}
common_batch_clear(spec->batch_dft);
common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
spec->prompt_last.push_back(id_last);
LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
llama_decode(spec->params.ctx_dft, spec->batch_dft);
common_sampler_reset(spec->smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < spec->params.n_draft; ++i) {
common_batch_clear(spec->batch_dft);
@@ -111,18 +156,13 @@ void common_speculative_add_draft(
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
if (cur_p->data[0].p < spec->params.p_min) {
break;
}
common_sampler_accept(spec->smpl, id, true);
spec->tokens.push_back(id);
// add unique drafted tokens to the target batch
spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
if (batch_tgt.n_tokens > spec->params.n_draft) {
break;
@@ -132,23 +172,13 @@ void common_speculative_add_draft(
// evaluate the drafted tokens on the draft model
llama_decode(spec->params.ctx_dft, spec->batch_dft);
spec->prompt_last.push_back(id);
}
// don't waste time on small batches
// TODO: do not evaluate the draft model for that many rounds
if (batch_tgt.n_tokens < spec->params.n_min) {
batch_tgt.n_tokens = 1;
spec->tokens.resize(0);
spec->i_batch_tgt.resize(1);
}
// print current draft sequences
LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
}
std::vector<llama_token> common_speculative_sample(
struct common_speculative * spec,
struct common_sampler * smpl,
struct llama_context * ctx_tgt) {
return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
}
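To make the chunk-reuse search in common_speculative_add_draft concrete, here is a self-contained sketch of just that loop with made-up values (plain ints stand in for tokens, and the prompt is assumed to fit in the draft context, i.e. i_start == 0):

#include <cassert>
#include <vector>

int main() {
    // tokens currently held in the draft context (spec->prompt_last)
    const std::vector<int> prompt_last = { 10, 20, 30, 40, 50 };
    // new prompt passed by the caller (everything up to, but not including, id_last)
    const std::vector<int> prompt      = { 20, 30, 40, 60, 70 };

    const int  i_start         = 0;    // max(0, prompt.size() - n_ctx) in the real code
    const int  n_reuse         = 256;  // common_speculative_params::n_reuse
    const bool prompt_fits_ctx = true; // stands in for "prompt.size() <= n_ctx"

    int reuse_i = 0;
    int reuse_n = 0;

    // find the longest chunk of prompt_last that lines up with the start of the new prompt
    for (int i = 0; i < (int) prompt_last.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt.size() &&
               i       + cur < (int) prompt_last.size() &&
               prompt[i_start + cur] == prompt_last[i + cur]) {
            cur++;
        }
        if ((cur >= n_reuse || prompt_fits_ctx) && cur > reuse_n) {
            reuse_i = i;
            reuse_n = cur;
        }
    }

    // the chunk { 20, 30, 40 } at offset 1 of prompt_last can be kept: the KV cells before
    // and after it are dropped, the kept cells are shifted by -reuse_i, and only { 60, 70 }
    // (plus id_last) still need to be decoded by the draft model
    assert(reuse_i == 1);
    assert(reuse_n == 3);

    return 0;
}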

View File

@@ -1,14 +1,16 @@
#pragma once
#include "llama.h"
#include <vector>
#include "common.h"
struct common_speculative;
struct common_speculative_params {
int n_draft = 16;
int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
int n_reuse = 256;
float p_min = 0.9f;
struct llama_model * model_dft = nullptr;
@@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
void common_speculative_free(struct common_speculative * spec);
// TODO: remove
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
// sample up to n_draft tokens and add them to the batch using the draft model
//
// TODO: change to:
//
// void common_speculative_add_draft(
// struct common_speculative * spec,
// struct llama_batch & batch_tgt,
// llama_token * tokens,
// int32_t n_tokens);
//
// and update the internal logic to compute only the new tokens
//
void common_speculative_add_draft(
struct common_speculative * spec,
struct llama_batch & batch_tgt,
const llama_tokens & prompt,
llama_token id_last,
int n_past);
std::vector<llama_token> common_speculative_sample(
struct common_speculative * spec,
struct common_sampler * smpl,
struct llama_context * ctx_tgt);
llama_token n_past_tgt);
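Put together, the trimmed header leaves a small surface. A hedged sketch of the intended calling pattern (a fragment, not a complete program; surrounding setup such as model_dft, ctx_dft, batch_tgt, prompt_dft, id_last and n_past is assumed, as in the example further below):

struct common_speculative_params params_spec;
params_spec.n_draft   = 16;
params_spec.n_min     = 5;
params_spec.n_reuse   = 256;
params_spec.p_min     = 0.9f;
params_spec.model_dft = model_dft;
params_spec.ctx_dft   = ctx_dft;

struct common_speculative * spec = common_speculative_init(params_spec);

// per iteration: pass the full confirmed prompt (minus the newest token) plus that
// newest token; the speculator reconciles its own draft context internally
common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);

// ... decode batch_tgt on the target model, then sample with common_sampler_sample_n ...

common_speculative_free(spec);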

View File

@@ -743,7 +743,7 @@ struct server_context {
}
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
// fraction of the common subsequence length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@@ -1960,7 +1960,7 @@ struct server_context {
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) {

View File

@@ -24,7 +24,6 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
// other common utils
//
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

View File

@@ -14,14 +14,6 @@
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
struct common_sampler * smpl = nullptr;
};
int main(int argc, char ** argv) {
common_params params;
@@ -165,27 +157,21 @@ int main(int argc, char ** argv) {
// note: keep the last token separate!
llama_token id_last = inp.back();
auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
int n_past = inp.size() - 1;
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_min = 5;
params_spec.n_reuse = 256;
params_spec.p_min = 0.9f;
params_spec.model_dft = model_dft;
params_spec.ctx_dft = ctx_dft;
struct common_speculative * spec = common_speculative_init(params_spec);
// feed the prompt to the speculator
//
// this has to be kept synchronized with the target context
//
// TODO: simplify this by moving the context management logic in the common_speculative instance
// for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
// speculator will automatically compute any new tokens that are not present in its context
//
common_speculative_set_prompt(spec, inp.data(), n_input - 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
const auto t_enc_end = ggml_time_us();
@@ -204,7 +190,7 @@ int main(int argc, char ** argv) {
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
@@ -220,7 +206,7 @@ int main(int argc, char ** argv) {
// available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft
//
const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
@@ -266,9 +252,11 @@ int main(int argc, char ** argv) {
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
}
prompt_dft.push_back(id_last);
prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
// remember the last accepted token for the next iteration
id_last = id;
}
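A worked illustration of the bookkeeping at the end of the loop (hypothetical tokens; id is taken to be the last entry of ids, per the "remember the last accepted token" comment):

// suppose the target accepted ids = { d0, d1, t }, where t is the first token that
// diverged from (or extended past) the draft:
//   - both KV caches are trimmed back to n_past, dropping cells of rejected draft tokens
//   - prompt_dft gains id_last, d0, d1, i.e. every confirmed token except the newest one
//   - id_last becomes t, so the next common_speculative_add_draft() call again receives
//     the confirmed prompt (prompt_dft) plus the single newest token (id_last)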