llama : refactor samplers internal implementation (#9370)

Commit: 19f4a7b296
Parent: 2a358fb0c4
Author: slaren
Date:   2024-09-08 15:52:07 +02:00 (committed by GitHub)
4 changed files with 881 additions and 725 deletions


@@ -101,6 +101,10 @@ struct ring_buffer {
     }

     void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
         if (sz == capacity) {
             // advance the start when buffer is full
             first = (first + 1) % capacity;
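
Note: for readers unfamiliar with the ring_buffer the samplers use, here is a minimal, self-contained sketch of the push_back semantics that the new guard protects. Only the push_back body mirrors the hunk above; the constructor, the at() accessor, and the member layout are illustrative assumptions, not llama.cpp's actual implementation.

    #include <cstdio>
    #include <stdexcept>
    #include <vector>

    template <typename T>
    struct ring_buffer {
        ring_buffer(size_t cap) : capacity(cap), data(cap) {}

        void push_back(const T & value) {
            if (capacity == 0) {
                throw std::runtime_error("ring buffer: capacity is zero");
            }
            if (sz == capacity) {
                // advance the start when buffer is full
                first = (first + 1) % capacity;
            } else {
                sz++;
            }
            data[pos] = value;
            pos = (pos + 1) % capacity;
        }

        // i-th element in insertion order, oldest first (illustrative helper)
        const T & at(size_t i) const { return data[(first + i) % capacity]; }

        size_t capacity = 0;
        size_t first    = 0;
        size_t pos      = 0;
        size_t sz       = 0;
        std::vector<T> data;
    };

    int main() {
        ring_buffer<int> rb(3);
        for (int i = 1; i <= 5; i++) {
            rb.push_back(i); // 4 and 5 overwrite the oldest entries
        }
        for (size_t i = 0; i < rb.sz; i++) {
            printf("%d ", rb.at(i)); // prints: 3 4 5
        }
        printf("\n");

        try {
            ring_buffer<int> empty(0);
            empty.push_back(1); // the new guard turns this into an exception
        } catch (const std::runtime_error & e) {
            printf("caught: %s\n", e.what());
        }
        return 0;
    }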

(File diff suppressed because it is too large.)


@@ -23,16 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };

-using llama_token_cnt = std::unordered_map<llama_token, int>;
-
-// TODO: tmp exposed until test-sampling is fixed
-void llama_sampler_penalties_impl(
-       llama_token_data_array * cur_p,
-        const llama_token_cnt & token_count,
-                         float penalty_repeat,
-                         float penalty_freq,
-                         float penalty_present);
-
 struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,

@@ -148,15 +148,17 @@ static void test_penalties(
         cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_cnt token_count;
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+
     for (size_t i = 0; i < last_tokens.size(); i++) {
-        token_count[last_tokens[i]]++;
+        llama_sampler_accept(sampler, last_tokens[i]);
     }

-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
+    APPLY(sampler, &cur_p);
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
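
The test now drives the penalties through the same APPLY helper it already uses for the other samplers. APPLY's exact definition lives elsewhere in tests/test-sampling.cpp and is not part of this hunk; a plausible shape, assuming it applies the given sampler to the candidate array and then frees it (which would explain why sampler is not freed explicitly after its APPLY above), is:

    // sketch of the test helper, assuming apply-then-free semantics
    #define APPLY(sampler_expr, cur_p) do {        \
            auto * smpl_ = (sampler_expr);         \
            llama_sampler_apply(smpl_, (cur_p));   \
            llama_sampler_free(smpl_);             \
        } while (0)

Under that assumption, APPLY(llama_sampler_init_softmax(), &cur_p) constructs a throwaway softmax sampler, normalizes the candidate probabilities in place, and releases it, which is why the test can call it before and after the penalties pass without leaking.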