last n tokens done

This commit is contained in:
Henri Vasserman 2023-07-20 00:36:36 +03:00
parent 42591a0acd
commit dd3cf5760a
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986

View File

@ -165,15 +165,14 @@ static bool server_verbose = false;
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
// helper class to manage prompt loading and truncation
struct prompt_evaluator {
llama_context * ctx;
size_t n_ctx = 0;
//std::string prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
size_t num_prompt_tokens = 0;
//size_t num_tokens_predicted = 0;
//size_t n_remain = 0;
size_t repeat_last_n = 0;
size_t n_past = 0;
size_t n_keep = 0;
bool truncated = false;
@ -226,11 +225,12 @@ struct prompt_evaluator {
prompt_tokens = new_tokens;
}
last_n_tokens.resize(n_last);
// fill the last n tokens from the input even if context is truncated
repeat_last_n = n_last;
last_n_tokens.clear();
if (n_last > 0) {
const size_t s = std::min(n_last, num_prompt_tokens);
std::fill(last_n_tokens.begin(), last_n_tokens.end() - s, 0);
std::copy(prompt_tokens.end() - s, prompt_tokens.end(), last_n_tokens.begin());
last_n_tokens.insert(last_n_tokens.begin(),
std::max(prompt_tokens.begin(), prompt_tokens.end() - n_last), prompt_tokens.end());
}
// compare the evaluated prompt with the new prompt
@ -295,8 +295,10 @@ struct prompt_evaluator {
}
void append_token(llama_token id) {
if (last_n_tokens.size() > 0) {
last_n_tokens.erase(last_n_tokens.begin());
if (repeat_last_n > 0) {
if (last_n_tokens.size() >= repeat_last_n) {
last_n_tokens.erase(last_n_tokens.begin());
}
last_n_tokens.push_back(id);
}
embd.push_back(id);
@ -410,7 +412,6 @@ struct llama_server_context {
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
//const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
@ -444,7 +445,6 @@ struct llama_server_context {
// Apply penalties
float nl_logit = logits[llama_token_nl()];
//auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
repeat_penalty);