mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 20:04:35 +00:00
speculative : manage context in common_speculative
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
ggml-ci
This commit is contained in:
parent
fe043ff1ff
commit
0f878a657c
@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
||||
[](const unsigned char c) { return !std::isprint(c); }),
|
||||
detokenized.end());
|
||||
|
||||
buf << "\n" << std::to_string(i)
|
||||
<< ":token '" << detokenized << "'"
|
||||
<< ":pos " << std::to_string(batch.pos[i])
|
||||
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
|
||||
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
|
||||
<< ":logits " << std::to_string(batch.logits[i]);
|
||||
buf << "\n" << std::to_string(i)
|
||||
<< ", token '" << detokenized << "'"
|
||||
<< ", pos " << std::to_string(batch.pos[i])
|
||||
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
|
||||
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
|
||||
<< ", logits " << std::to_string(batch.logits[i]);
|
||||
}
|
||||
|
||||
buf << " ]";
|
||||
@ -1490,6 +1490,66 @@ void common_batch_add(
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
|
||||
// check for empty sequences
|
||||
if (a.empty() || b.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input sequences
|
||||
size_t a_len = a.size();
|
||||
size_t b_len = b.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
size_t max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<size_t> prev_row(b_len + 1, 0);
|
||||
std::vector<size_t> curr_row(b_len + 1, 0);
|
||||
|
||||
// iterate through the elements of a
|
||||
for (size_t i = 1; i <= a_len; i++) {
|
||||
// iterate through the elements of b
|
||||
for (size_t j = 1; j <= b_len; j++) {
|
||||
// if elements at the current positions match
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
// if it's the first element of either sequences, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous element
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if elements don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
|
||||
struct llama_lora_adapter * adapter;
|
||||
};
|
||||
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
// build info
|
||||
extern int LLAMA_BUILD_NUMBER;
|
||||
extern char const * LLAMA_COMMIT;
|
||||
@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
|
||||
// clear LoRA adapters from context, then apply new list of adapters
|
||||
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
|
||||
|
||||
//
|
||||
// Batch utils
|
||||
//
|
||||
|
||||
void common_batch_clear(struct llama_batch & batch);
|
||||
|
||||
@ -472,6 +476,16 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
// longest common prefix
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
// longet common subsequence
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
@ -342,6 +342,28 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
|
||||
std::vector<int> idxs;
|
||||
idxs.reserve(batch.n_tokens);
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
draft.reserve(batch.n_tokens);
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (idxs.size() > 0) {
|
||||
GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
|
||||
draft.push_back(batch.token[i]);
|
||||
}
|
||||
idxs.push_back(i);
|
||||
}
|
||||
|
||||
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||
}
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
return llama_sampler_get_seed(gsmpl->chain);
|
||||
}
|
||||
|
@ -73,6 +73,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
//
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
|
||||
// helpers
|
||||
|
@ -11,9 +11,7 @@ struct common_speculative {
|
||||
|
||||
struct common_sampler * smpl;
|
||||
|
||||
std::vector<int> i_batch_tgt;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
llama_tokens prompt_last;
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(struct common_speculative_params params) {
|
||||
@ -21,12 +19,10 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
|
||||
/* .params = */ params,
|
||||
/* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
|
||||
/* .smpl = */ nullptr,
|
||||
/* .i_batch_tgt = */ {},
|
||||
/* .tokens = */ {},
|
||||
};
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
#if 0
|
||||
#if 1
|
||||
{
|
||||
common_sampler_params sparams;
|
||||
sparams.no_perf = false;
|
||||
@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
|
||||
delete spec;
|
||||
}
|
||||
|
||||
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
|
||||
llama_kv_cache_clear(spec->params.ctx_dft);
|
||||
|
||||
// TODO: error handling
|
||||
llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
|
||||
}
|
||||
|
||||
void common_speculative_add_draft(
|
||||
struct common_speculative * spec,
|
||||
struct llama_batch & batch_tgt,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last,
|
||||
int n_past) {
|
||||
spec->tokens.clear();
|
||||
llama_token n_past_tgt) {
|
||||
|
||||
spec->i_batch_tgt.clear();
|
||||
spec->i_batch_tgt.push_back(0);
|
||||
int reuse_i = 0;
|
||||
int reuse_n = 0;
|
||||
|
||||
common_sampler_reset(spec->smpl);
|
||||
const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
|
||||
|
||||
for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
|
||||
int cur = 0;
|
||||
while (i_start + cur < (int) prompt.size() &&
|
||||
i + cur < (int) spec->prompt_last.size() &&
|
||||
prompt[i_start + cur] == spec->prompt_last[i + cur]) {
|
||||
cur++;
|
||||
}
|
||||
|
||||
if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
|
||||
|
||||
if (reuse_n == 0) {
|
||||
llama_kv_cache_clear(spec->params.ctx_dft);
|
||||
|
||||
spec->prompt_last.clear();
|
||||
} else {
|
||||
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
|
||||
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
|
||||
llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
|
||||
spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
|
||||
}
|
||||
|
||||
common_batch_clear(spec->batch_dft);
|
||||
|
||||
for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
|
||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
|
||||
common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
|
||||
|
||||
spec->prompt_last.push_back(prompt[i]);
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt.size() - i_start;
|
||||
|
||||
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
||||
|
||||
if (spec->batch_dft.n_tokens > 0) {
|
||||
LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
|
||||
|
||||
llama_decode(spec->params.ctx_dft, spec->batch_dft);
|
||||
}
|
||||
|
||||
common_batch_clear(spec->batch_dft);
|
||||
common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
|
||||
|
||||
spec->prompt_last.push_back(id_last);
|
||||
|
||||
LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
|
||||
|
||||
llama_decode(spec->params.ctx_dft, spec->batch_dft);
|
||||
|
||||
common_sampler_reset(spec->smpl);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < spec->params.n_draft; ++i) {
|
||||
common_batch_clear(spec->batch_dft);
|
||||
@ -111,18 +156,13 @@ void common_speculative_add_draft(
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
|
||||
if (cur_p->data[0].p < spec->params.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_accept(spec->smpl, id, true);
|
||||
|
||||
spec->tokens.push_back(id);
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
|
||||
common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
|
||||
common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
|
||||
|
||||
if (batch_tgt.n_tokens > spec->params.n_draft) {
|
||||
break;
|
||||
@ -132,23 +172,13 @@ void common_speculative_add_draft(
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
llama_decode(spec->params.ctx_dft, spec->batch_dft);
|
||||
|
||||
spec->prompt_last.push_back(id);
|
||||
}
|
||||
|
||||
// don't waste time on small batches
|
||||
// TODO: do not evaluate the draft model for that many rounds
|
||||
if (batch_tgt.n_tokens < spec->params.n_min) {
|
||||
batch_tgt.n_tokens = 1;
|
||||
spec->tokens.resize(0);
|
||||
spec->i_batch_tgt.resize(1);
|
||||
}
|
||||
|
||||
// print current draft sequences
|
||||
LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_speculative_sample(
|
||||
struct common_speculative * spec,
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx_tgt) {
|
||||
return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
|
||||
}
|
||||
|
@ -1,14 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <vector>
|
||||
#include "common.h"
|
||||
|
||||
struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16;
|
||||
int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f;
|
||||
|
||||
struct llama_model * model_dft = nullptr;
|
||||
|
||||
@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
|
||||
// TODO: remove
|
||||
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
//
|
||||
// TODO: change to:
|
||||
//
|
||||
// void common_speculative_add_draft(
|
||||
// struct common_speculative * spec,
|
||||
// struct llama_batch & batch_tgt,
|
||||
// llama_token * tokens,
|
||||
// int32_t n_tokens);
|
||||
//
|
||||
// and update the internal logic to compute only the new tokens
|
||||
//
|
||||
void common_speculative_add_draft(
|
||||
struct common_speculative * spec,
|
||||
struct llama_batch & batch_tgt,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last,
|
||||
int n_past);
|
||||
|
||||
std::vector<llama_token> common_speculative_sample(
|
||||
struct common_speculative * spec,
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx_tgt);
|
||||
llama_token n_past_tgt);
|
||||
|
@ -743,7 +743,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
|
||||
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
|
||||
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
|
||||
|
||||
// fraction of the common subsequence length compared to the current slot's prompt length
|
||||
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
|
||||
@ -1960,7 +1960,7 @@ struct server_context {
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
|
@ -24,7 +24,6 @@
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
|
||||
// other common utils
|
||||
//
|
||||
|
||||
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
|
||||
// check for empty sequences
|
||||
if (a.empty() || b.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input sequences
|
||||
size_t a_len = a.size();
|
||||
size_t b_len = b.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
size_t max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<size_t> prev_row(b_len + 1, 0);
|
||||
std::vector<size_t> curr_row(b_len + 1, 0);
|
||||
|
||||
// iterate through the elements of a
|
||||
for (size_t i = 1; i <= a_len; i++) {
|
||||
// iterate through the elements of b
|
||||
for (size_t j = 1; j <= b_len; j++) {
|
||||
// if elements at the current positions match
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
// if it's the first element of either sequences, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous element
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if elements don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
@ -14,14 +14,6 @@
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct seq_draft {
|
||||
std::vector<int> i_batch_tgt;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
struct common_sampler * smpl = nullptr;
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
@ -165,27 +157,21 @@ int main(int argc, char ** argv) {
|
||||
// note: keep the last token separate!
|
||||
llama_token id_last = inp.back();
|
||||
|
||||
auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
|
||||
|
||||
int n_past = inp.size() - 1;
|
||||
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_min = 5;
|
||||
params_spec.n_reuse = 256;
|
||||
params_spec.p_min = 0.9f;
|
||||
params_spec.model_dft = model_dft;
|
||||
params_spec.ctx_dft = ctx_dft;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params_spec);
|
||||
|
||||
// feed the prompt to the speculator
|
||||
//
|
||||
// this has to be kept synchronized with the target context
|
||||
//
|
||||
// TODO: simplify this by moving the context management logic in the common_speculative instance
|
||||
// for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
|
||||
// speculator will automatically compute any new tokens that are not present in its context
|
||||
//
|
||||
common_speculative_set_prompt(spec, inp.data(), n_input - 1);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
@ -204,7 +190,7 @@ int main(int argc, char ** argv) {
|
||||
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
|
||||
common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
@ -220,7 +206,7 @@ int main(int argc, char ** argv) {
|
||||
// available logits from the batch and sample the next token until we run out of logits or the sampler
|
||||
// disagrees with the draft
|
||||
//
|
||||
const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
|
||||
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
|
||||
|
||||
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
|
||||
|
||||
@ -266,9 +252,11 @@ int main(int argc, char ** argv) {
|
||||
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
|
||||
}
|
||||
|
||||
prompt_dft.push_back(id_last);
|
||||
prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
|
||||
|
||||
// remember the last accepted token for the next iteration
|
||||
id_last = id;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user