mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
new sampling API
This commit is contained in:
parent
84b8f2b060
commit
7196c4e08a
@ -125,7 +125,7 @@ enum slot_command {
|
|||||||
struct slot_params {
|
struct slot_params {
|
||||||
bool stream = true;
|
bool stream = true;
|
||||||
uint32_t seed = -1; // RNG seed
|
uint32_t seed = -1; // RNG seed
|
||||||
int n_keep = 0; // RNG seed
|
int n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
|
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
|
||||||
bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
|
bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
|
||||||
@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params sparams, std::string grammar, int n_ctx) {
|
||||||
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
|
result->params = sparams;
|
||||||
|
result->grammar = nullptr;
|
||||||
|
|
||||||
|
// if there is a grammar, parse it
|
||||||
|
if (!grammar.empty()) {
|
||||||
|
result->parsed_grammar = grammar_parser::parse(grammar.c_str());
|
||||||
|
|
||||||
|
// will be empty (default) if there are parse errors
|
||||||
|
if (result->parsed_grammar.rules.empty()) {
|
||||||
|
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
||||||
|
|
||||||
|
result->grammar = llama_grammar_init(
|
||||||
|
grammar_rules.data(),
|
||||||
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
|
}
|
||||||
|
|
||||||
|
result->prev.resize(n_ctx);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
struct slot_image {
|
struct slot_image {
|
||||||
clip_image_u8 img_data;
|
clip_image_u8 img_data;
|
||||||
bool request_encode_image = false;
|
bool request_encode_image = false;
|
||||||
@ -287,7 +315,6 @@ struct llama_client_slot
|
|||||||
int num_tokens_predicted = 0;
|
int num_tokens_predicted = 0;
|
||||||
llama_token sampled;
|
llama_token sampled;
|
||||||
std::vector<llama_token> cache_tokens;
|
std::vector<llama_token> cache_tokens;
|
||||||
std::vector<llama_token> last_n_tokens;
|
|
||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
int sent_tokens = 0;
|
int sent_tokens = 0;
|
||||||
slot_state state = IDLE;
|
slot_state state = IDLE;
|
||||||
@ -307,13 +334,12 @@ struct llama_client_slot
|
|||||||
double t_token_generation; // ms
|
double t_token_generation; // ms
|
||||||
|
|
||||||
struct slot_params params;
|
struct slot_params params;
|
||||||
struct llama_sampling_params sparams;
|
|
||||||
llama_sampling_context ctx_sampling;
|
|
||||||
bool has_next_token = true;
|
|
||||||
|
|
||||||
// grammar props
|
// sampling
|
||||||
grammar_parser::parse_state parsed_grammar;
|
struct llama_sampling_params sparams;
|
||||||
llama_grammar *grammar = nullptr;
|
llama_sampling_context* ctx_sampling = nullptr;
|
||||||
|
bool has_next_token = true;
|
||||||
|
int max_context_size = 0;
|
||||||
|
|
||||||
// multimodal
|
// multimodal
|
||||||
std::vector<slot_image> images;
|
std::vector<slot_image> images;
|
||||||
@ -332,47 +358,26 @@ struct llama_client_slot
|
|||||||
infill = false;
|
infill = false;
|
||||||
clean_tokens();
|
clean_tokens();
|
||||||
|
|
||||||
if (grammar != nullptr) {
|
if (ctx_sampling != nullptr) {
|
||||||
llama_grammar_free(grammar);
|
llama_sampling_free(ctx_sampling);
|
||||||
grammar = nullptr;
|
|
||||||
ctx_sampling.params = sparams;
|
|
||||||
ctx_sampling.grammar = NULL;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
|
||||||
|
|
||||||
for(slot_image img : images) {
|
for(slot_image img : images) {
|
||||||
free(img.image_embedding);
|
free(img.image_embedding);
|
||||||
delete[] img.img_data.data;
|
delete[] img.img_data.data;
|
||||||
img.prefix_prompt = "";
|
img.prefix_prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
images.clear();
|
images.clear();
|
||||||
// llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
|
// llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
|
||||||
}
|
}
|
||||||
|
|
||||||
bool loadGrammar(llama_token eos)
|
bool loadGrammar(llama_token eos)
|
||||||
{
|
{
|
||||||
if (!params.grammar.empty()) {
|
ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
|
||||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
return ctx_sampling != nullptr;
|
||||||
// will be empty (default) if there are parse errors
|
|
||||||
if (parsed_grammar.rules.empty()) {
|
|
||||||
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
|
||||||
|
|
||||||
{
|
|
||||||
auto it = sparams.logit_bias.find(eos);
|
|
||||||
if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
|
|
||||||
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
||||||
}
|
|
||||||
ctx_sampling.params = sparams;
|
|
||||||
ctx_sampling.grammar = grammar;
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasBudget(gpt_params &global_params) {
|
bool hasBudget(gpt_params &global_params) {
|
||||||
@ -448,7 +453,6 @@ struct llama_server_context
|
|||||||
llama_model *model = nullptr;
|
llama_model *model = nullptr;
|
||||||
llama_context *ctx = nullptr;
|
llama_context *ctx = nullptr;
|
||||||
llama_batch batch;
|
llama_batch batch;
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
bool all_slots_are_idle = false;
|
bool all_slots_are_idle = false;
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
int n_ctx;
|
int n_ctx;
|
||||||
@ -468,11 +472,6 @@ struct llama_server_context
|
|||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
model = nullptr;
|
model = nullptr;
|
||||||
}
|
}
|
||||||
for(auto &slot : slots) {
|
|
||||||
if(slot.grammar) {
|
|
||||||
llama_grammar_free(slot.grammar);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool loadModel(const gpt_params ¶ms_)
|
bool loadModel(const gpt_params ¶ms_)
|
||||||
@ -510,7 +509,6 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
n_ctx = llama_n_ctx(ctx);
|
n_ctx = llama_n_ctx(ctx);
|
||||||
n_vocab = llama_n_vocab(model);
|
n_vocab = llama_n_vocab(model);
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -529,13 +527,12 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
llama_client_slot slot;
|
llama_client_slot slot;
|
||||||
slot.id = i;
|
slot.id = i;
|
||||||
slot.last_n_tokens.resize(max_ctx_per_slot);
|
slot.max_context_size = max_ctx_per_slot;
|
||||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
|
|
||||||
slot.reset();
|
slot.reset();
|
||||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
|
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
|
||||||
slots.push_back(slot);
|
slots.push_back(slot);
|
||||||
}
|
}
|
||||||
batch = llama_batch_init(n_ctx, 0);
|
batch = llama_batch_init(n_ctx, 0, 1);
|
||||||
// empty system prompt
|
// empty system prompt
|
||||||
system_prompt = "";
|
system_prompt = "";
|
||||||
num_tokens_system = 0;
|
num_tokens_system = 0;
|
||||||
@ -626,10 +623,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
for (int32_t i = 0; i < batch.n_tokens; ++i)
|
for (int32_t i = 0; i < batch.n_tokens; ++i)
|
||||||
{
|
{
|
||||||
batch.token[i] = tokens_system[i];
|
llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
|
||||||
batch.pos[i] = i;
|
|
||||||
batch.seq_id[i] = 0;
|
|
||||||
batch.logits[i] = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0)
|
if (llama_decode(ctx, batch) != 0)
|
||||||
@ -726,8 +720,6 @@ struct llama_server_context
|
|||||||
|
|
||||||
bool processToken(completion_token_output & result, llama_client_slot & slot) {
|
bool processToken(completion_token_output & result, llama_client_slot & slot) {
|
||||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||||
slot.last_n_tokens.erase(slot.last_n_tokens.begin());
|
|
||||||
slot.last_n_tokens.push_back(result.tok);
|
|
||||||
const std::string token_str = llama_token_to_piece(ctx, result.tok);
|
const std::string token_str = llama_token_to_piece(ctx, result.tok);
|
||||||
slot.sampled = result.tok;
|
slot.sampled = result.tok;
|
||||||
|
|
||||||
@ -859,11 +851,12 @@ struct llama_server_context
|
|||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
llama_batch batch_view = {
|
llama_batch batch_view = {
|
||||||
n_tokens,
|
n_tokens,
|
||||||
batch.token + i,
|
batch.token + i,
|
||||||
nullptr,
|
nullptr,
|
||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.logits + i,
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
0, 0, 0, // unused
|
||||||
};
|
};
|
||||||
if (llama_decode(ctx, batch_view)) {
|
if (llama_decode(ctx, batch_view)) {
|
||||||
@ -878,8 +871,8 @@ struct llama_server_context
|
|||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode(ctx, batch_img)) {
|
||||||
LOG_TEE("%s : failed to eval image\n", __func__);
|
LOG_TEE("%s : failed to eval image\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -894,10 +887,7 @@ struct llama_server_context
|
|||||||
(json)(slot.images[image_idx].prefix_prompt);
|
(json)(slot.images[image_idx].prefix_prompt);
|
||||||
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
||||||
for (int i = 0; i < append_tokens.size(); ++i) {
|
for (int i = 0; i < append_tokens.size(); ++i) {
|
||||||
batch.token [batch.n_tokens] = append_tokens[i];
|
llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
|
||||||
batch.pos [batch.n_tokens] = slot.n_past;
|
|
||||||
batch.seq_id[batch.n_tokens] = slot.id;
|
|
||||||
batch.logits[batch.n_tokens] = false;
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
batch.n_tokens += 1;
|
batch.n_tokens += 1;
|
||||||
}
|
}
|
||||||
@ -922,7 +912,6 @@ struct llama_server_context
|
|||||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||||
}
|
}
|
||||||
|
|
||||||
// context shift takes effect only when there is a single slot
|
|
||||||
for(llama_client_slot &slot : slots) {
|
for(llama_client_slot &slot : slots) {
|
||||||
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
|
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
|
||||||
{
|
{
|
||||||
@ -976,16 +965,12 @@ struct llama_server_context
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.token [batch.n_tokens] = slot.sampled;
|
slot.i_batch = batch.n_tokens;
|
||||||
batch.pos [batch.n_tokens] = num_tokens_system + slot.n_past;
|
|
||||||
batch.seq_id[batch.n_tokens] = slot.id;
|
llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
|
||||||
batch.logits[batch.n_tokens] = true;
|
|
||||||
|
|
||||||
slot.n_decoded += 1;
|
slot.n_decoded += 1;
|
||||||
slot.i_batch = batch.n_tokens;
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
|
|
||||||
batch.n_tokens += 1;
|
|
||||||
}
|
}
|
||||||
// process in chunks of params.n_batch
|
// process in chunks of params.n_batch
|
||||||
int32_t n_batch = params.n_batch;
|
int32_t n_batch = params.n_batch;
|
||||||
@ -1026,7 +1011,7 @@ struct llama_server_context
|
|||||||
slot.num_prompt_tokens = prompt_tokens.size();
|
slot.num_prompt_tokens = prompt_tokens.size();
|
||||||
|
|
||||||
if(!slot.params.cache_prompt) {
|
if(!slot.params.cache_prompt) {
|
||||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
|
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
||||||
} else {
|
} else {
|
||||||
@ -1038,23 +1023,27 @@ struct llama_server_context
|
|||||||
//if input prompt is too big, truncate like normal
|
//if input prompt is too big, truncate like normal
|
||||||
if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
|
if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
|
||||||
{
|
{
|
||||||
|
// applied bug of #3661
|
||||||
const int n_left = max_ctx_per_slot - slot.params.n_keep;
|
const int n_left = max_ctx_per_slot - slot.params.n_keep;
|
||||||
|
const int n_block_size = n_left / 2;
|
||||||
|
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
|
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
|
||||||
// Use half the left-over space in the context for the prompt
|
// Use half the left-over space in the context for the prompt
|
||||||
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
|
new_tokens.insert(new_tokens.end(), prompt_tokens.end() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
|
||||||
LOG_VERBOSE("input truncated", {
|
LOG_VERBOSE("input truncated", {
|
||||||
{"n_ctx", n_ctx},
|
{"n_ctx", max_ctx_per_slot},
|
||||||
{"n_keep", params.n_keep},
|
{"n_keep", slot.params.n_keep},
|
||||||
{"n_left", n_left},
|
{"n_left", n_left},
|
||||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||||
});
|
});
|
||||||
slot.truncated = true;
|
slot.truncated = true;
|
||||||
prompt_tokens = new_tokens;
|
prompt_tokens = new_tokens;
|
||||||
slot.num_prompt_tokens = prompt_tokens.size();
|
slot.num_prompt_tokens = prompt_tokens.size();
|
||||||
|
GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
|
||||||
}
|
}
|
||||||
const size_t ps = slot.num_prompt_tokens;
|
const size_t ps = slot.num_prompt_tokens;
|
||||||
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
|
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
|
||||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
|
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
|
||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||||
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||||
@ -1081,11 +1070,7 @@ struct llama_server_context
|
|||||||
// process the prefix of first image
|
// process the prefix of first image
|
||||||
std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
|
std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
|
||||||
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
|
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
|
||||||
batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
|
llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
|
||||||
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
|
|
||||||
batch.seq_id[batch.n_tokens] = slot.id;
|
|
||||||
batch.logits[batch.n_tokens] = false;
|
|
||||||
batch.n_tokens += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(ingest_images && !ingestImages(slot, n_batch)) {
|
if(ingest_images && !ingestImages(slot, n_batch)) {
|
||||||
@ -1113,11 +1098,12 @@ struct llama_server_context
|
|||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
llama_batch batch_view = {
|
llama_batch batch_view = {
|
||||||
n_tokens,
|
n_tokens,
|
||||||
batch.token + i,
|
batch.token + i,
|
||||||
nullptr,
|
nullptr,
|
||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.logits + i,
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
0, 0, 0, // unused
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1150,25 +1136,27 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
|
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
||||||
|
|
||||||
|
llama_sampling_accept(slot.ctx_sampling, ctx, id);
|
||||||
|
|
||||||
if (slot.n_decoded == 1) {
|
if (slot.n_decoded == 1) {
|
||||||
slot.t_start_genereration = ggml_time_us();
|
slot.t_start_genereration = ggml_time_us();
|
||||||
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
const int32_t n_probs = slot.sparams.n_probs;
|
const int32_t n_probs = slot.sparams.n_probs;
|
||||||
if (slot.sparams.temp <= 0 && n_probs > 0)
|
if (slot.sparams.temp <= 0 && n_probs > 0)
|
||||||
{
|
{
|
||||||
// For llama_sample_token_greedy we need to sort candidates
|
// For llama_sample_token_greedy we need to sort candidates
|
||||||
llama_sample_softmax(ctx, &candidates_p);
|
llama_sample_softmax(ctx, &cur_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
|
for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
|
||||||
{
|
{
|
||||||
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
|
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!processToken(result, slot)) {
|
if (!processToken(result, slot)) {
|
||||||
|
Loading…
Reference in New Issue
Block a user