mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
latest changes of sampling API
This commit is contained in:
parent
8540568c48
commit
ab2fc00224
@ -3,6 +3,9 @@
|
||||
#include "build-info.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include "../llava/clip.h"
|
||||
#include "stb_image.h"
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
||||
@ -308,18 +311,14 @@ struct llama_client_slot
|
||||
int32_t n_remaining = -1;
|
||||
|
||||
json prompt;
|
||||
std::vector<llama_token> embd;
|
||||
std::vector<llama_token> last_n_tokens;
|
||||
|
||||
llama_model *model = nullptr;
|
||||
llama_context *ctx = nullptr;
|
||||
gpt_params params;
|
||||
llama_sampling_context ctx_sampling;
|
||||
int n_ctx;
|
||||
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
llama_grammar *grammar = nullptr;
|
||||
|
||||
std::string generated_text = "";
|
||||
int num_tokens_predicted = 0;
|
||||
llama_token sampled;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
int sent_tokens = 0;
|
||||
slot_state state = IDLE;
|
||||
slot_command command = NONE;
|
||||
bool truncated = false;
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
@ -475,34 +474,22 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
bool loadModel(const gpt_params ¶ms_)
|
||||
{
|
||||
params.antiprompt.clear();
|
||||
params.grammar.clear();
|
||||
num_prompt_tokens = 0;
|
||||
num_tokens_predicted = 0;
|
||||
generated_text = "";
|
||||
generated_text.reserve(n_ctx);
|
||||
generated_token_probs.clear();
|
||||
truncated = false;
|
||||
stopped_eos = false;
|
||||
stopped_word = false;
|
||||
stopped_limit = false;
|
||||
stopping_word = "";
|
||||
multibyte_pending = 0;
|
||||
n_remain = 0;
|
||||
n_past = 0;
|
||||
|
||||
if (grammar != nullptr) {
|
||||
llama_grammar_free(grammar);
|
||||
grammar = nullptr;
|
||||
ctx_sampling = llama_sampling_context_init(params, NULL);
|
||||
}
|
||||
}
|
||||
|
||||
bool loadModel(const gpt_params ¶ms_)
|
||||
{
|
||||
params = params_;
|
||||
if(!params.mmproj.empty()) {
|
||||
multimodal = true;
|
||||
LOG_TEE("Multi Modal Mode Enabled");
|
||||
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
||||
if(clp_ctx == nullptr) {
|
||||
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
||||
return false;
|
||||
}
|
||||
|
||||
if(params.n_ctx < 2048) { // request larger context for the image embedding
|
||||
params.n_ctx = 2048;
|
||||
}
|
||||
}
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
if (model == nullptr)
|
||||
{
|
||||
@ -521,8 +508,7 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
last_n_tokens.resize(n_ctx);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
n_vocab = llama_n_vocab(model);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -597,29 +583,24 @@ struct llama_server_context
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
bool loadGrammar()
|
||||
{
|
||||
if (!params.grammar.empty()) {
|
||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
|
||||
return false;
|
||||
}
|
||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
||||
|
||||
llama_client_slot* getSlot(int id) {
|
||||
for (llama_client_slot & slot : slots)
|
||||
{
|
||||
if ((id == -1 && slot.available()) || slot.id == id)
|
||||
{
|
||||
auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
|
||||
if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
|
||||
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
|
||||
}
|
||||
return &slot;
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
}
|
||||
ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool launchSlot(llama_client_slot* &slot) {
|
||||
if(!slot->loadGrammar(llama_token_eos(ctx))) {
|
||||
return false;
|
||||
}
|
||||
all_slots_are_idle = false;
|
||||
slot->command = LOAD_PROMPT;
|
||||
LOG_TEE("slot %i is processing\n", slot->id);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -665,119 +646,253 @@ struct llama_server_context
|
||||
// release all slots
|
||||
for (llama_client_slot &slot : slots)
|
||||
{
|
||||
params.n_keep = (int)num_prompt_tokens;
|
||||
slot.release();
|
||||
}
|
||||
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
|
||||
waitAllAreIdle();
|
||||
all_slots_are_idle = true;
|
||||
|
||||
// if input prompt is too big, truncate like normal
|
||||
if (num_prompt_tokens >= (size_t)params.n_ctx)
|
||||
{
|
||||
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
|
||||
// todo we probably want to cut from both sides
|
||||
const int n_left = (params.n_ctx - params.n_keep) / 2;
|
||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", params.n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
});
|
||||
|
||||
truncated = true;
|
||||
prompt_tokens = new_tokens;
|
||||
// wait until system prompt load
|
||||
update_system_prompt = true;
|
||||
while(update_system_prompt) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
}
|
||||
else
|
||||
{
|
||||
const size_t ps = num_prompt_tokens;
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
|
||||
}
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
n_past = common_part(embd, prompt_tokens);
|
||||
embd = prompt_tokens;
|
||||
|
||||
if (n_past == num_prompt_tokens)
|
||||
{
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
printf("we have to evaluate at least 1 token to generate logits\n");
|
||||
n_past--;
|
||||
}
|
||||
|
||||
// since #3228 we now have to manually manage the KV cache
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", n_past},
|
||||
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
|
||||
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
|
||||
});
|
||||
|
||||
has_next_token = true;
|
||||
// system prompt loaded, continue
|
||||
}
|
||||
void loadPrompt()
|
||||
|
||||
void processSystemPromptData(json sys_props) {
|
||||
system_prompt = sys_props.value("prompt", "");
|
||||
user_name = sys_props.value("anti_prompt", "");
|
||||
assistant_name = sys_props.value("assistant_name", "");
|
||||
if(slots.size() > 0) {
|
||||
notifySystemPromptChanged();
|
||||
} else {
|
||||
update_system_prompt = true;
|
||||
}
|
||||
}
|
||||
|
||||
void waitAllAreIdle() {
|
||||
bool wait = true;
|
||||
while(wait) {
|
||||
wait = false;
|
||||
for (auto &slot : slots)
|
||||
{
|
||||
if (!slot.available())
|
||||
{
|
||||
wait = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
|
||||
const stop_type type, llama_client_slot &slot)
|
||||
{
|
||||
auto prompt_tokens = tokenize(prompt, true); // always add BOS
|
||||
size_t stop_pos = std::string::npos;
|
||||
for (const std::string &word : slot.params.antiprompt)
|
||||
{
|
||||
size_t pos;
|
||||
if (type == STOP_FULL)
|
||||
{
|
||||
const size_t tmp = word.size() + last_token_size;
|
||||
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
||||
pos = text.find(word, from_pos);
|
||||
}
|
||||
else
|
||||
{
|
||||
pos = find_partial_stop_string(word, text);
|
||||
}
|
||||
if (pos != std::string::npos &&
|
||||
(stop_pos == std::string::npos || pos < stop_pos))
|
||||
{
|
||||
if (type == STOP_FULL)
|
||||
{
|
||||
slot.stopped_word = true;
|
||||
slot.stopping_word = word;
|
||||
slot.has_next_token = false;
|
||||
}
|
||||
stop_pos = pos;
|
||||
|
||||
num_prompt_tokens = prompt_tokens.size();
|
||||
}
|
||||
}
|
||||
return stop_pos;
|
||||
}
|
||||
|
||||
if (params.n_keep < 0)
|
||||
bool processToken(completion_token_output & result, llama_client_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = llama_token_to_piece(ctx, result.tok);
|
||||
slot.sampled = result.tok;
|
||||
|
||||
// search stop word and delete it
|
||||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
||||
size_t pos = std::min(slot.sent_count, slot.generated_text.size());
|
||||
const std::string str_test = slot.generated_text.substr(pos);
|
||||
bool is_stop_full = false;
|
||||
size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
|
||||
if (stop_pos != std::string::npos) {
|
||||
is_stop_full = true;
|
||||
slot.generated_text.erase(
|
||||
slot.generated_text.begin() + pos + stop_pos,
|
||||
slot.generated_text.end());
|
||||
pos = std::min(slot.sent_count, slot.generated_text.size());
|
||||
} else {
|
||||
is_stop_full = false;
|
||||
stop_pos = findStoppingStrings(str_test, token_str.size(),
|
||||
STOP_PARTIAL, slot);
|
||||
}
|
||||
|
||||
// check if there is any token to predict
|
||||
if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
|
||||
// no send the stop word in the response
|
||||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||
slot.sent_count += result.text_to_send.size();
|
||||
// add the token to slot queue and cache
|
||||
}
|
||||
slot.addTokenString(result);
|
||||
if (slot.multibyte_pending > 0)
|
||||
{
|
||||
slot.multibyte_pending -= token_str.size();
|
||||
}
|
||||
else if (token_str.size() == 1)
|
||||
{
|
||||
const char c = token_str[0];
|
||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
||||
if ((c & 0xE0) == 0xC0)
|
||||
{
|
||||
slot.multibyte_pending = 1;
|
||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF0) == 0xE0)
|
||||
{
|
||||
slot.multibyte_pending = 2;
|
||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF8) == 0xF0)
|
||||
{
|
||||
slot.multibyte_pending = 3;
|
||||
}
|
||||
else
|
||||
{
|
||||
slot.multibyte_pending = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (slot.multibyte_pending > 0 && !slot.has_next_token)
|
||||
{
|
||||
slot.has_next_token = true;
|
||||
}
|
||||
|
||||
// check the limits
|
||||
if (
|
||||
slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params))
|
||||
{
|
||||
slot.stopped_limit = true;
|
||||
slot.has_next_token = false;
|
||||
}
|
||||
params.n_keep = std::min(n_ctx - 4, params.n_keep);
|
||||
|
||||
// if input prompt is too big, truncate like normal
|
||||
if (num_prompt_tokens >= (size_t)n_ctx)
|
||||
{
|
||||
const int n_left = (n_ctx - params.n_keep) / 2;
|
||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
});
|
||||
|
||||
truncated = true;
|
||||
prompt_tokens = new_tokens;
|
||||
}
|
||||
else
|
||||
{
|
||||
const size_t ps = num_prompt_tokens;
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
|
||||
slot.stopped_eos = true;
|
||||
slot.has_next_token = false;
|
||||
LOG_VERBOSE("eos token found", {});
|
||||
}
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
n_past = common_part(embd, prompt_tokens);
|
||||
LOG_VERBOSE("next token", {
|
||||
{"token", result.tok},
|
||||
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
|
||||
{"has_next_token", slot.has_next_token},
|
||||
{"n_remain", slot.n_remaining},
|
||||
{"num_tokens_predicted", slot.num_tokens_predicted},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
{"stopped_word", slot.stopped_word},
|
||||
{"stopped_limit", slot.stopped_limit},
|
||||
{"stopping_word", slot.stopping_word},
|
||||
});
|
||||
return slot.has_next_token; // continue
|
||||
}
|
||||
|
||||
embd = prompt_tokens;
|
||||
if (n_past == num_prompt_tokens)
|
||||
{
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
n_past--;
|
||||
bool processImages(llama_client_slot &slot) {
|
||||
for(slot_image &img : slot.images) {
|
||||
if(!img.request_encode_image) {
|
||||
continue;
|
||||
}
|
||||
clip_image_f32 img_res;
|
||||
if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, /*pad2square =*/ true)) {
|
||||
LOG_TEE("Error processing the given image");
|
||||
clip_free(clp_ctx);
|
||||
return false;
|
||||
}
|
||||
img.image_tokens = clip_n_patches(clp_ctx);
|
||||
img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
|
||||
if (!img.image_embedding) {
|
||||
LOG_TEE("Unable to allocate memory for image embeddings\n");
|
||||
clip_free(clp_ctx);
|
||||
return false;
|
||||
}
|
||||
LOG_TEE("slot %i - encoding image %i\n", slot.id, img.id);
|
||||
if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, img.image_embedding)) {
|
||||
LOG_TEE("Unable to encode image\n");
|
||||
return false;
|
||||
}
|
||||
img.request_encode_image = false;
|
||||
}
|
||||
return slot.images.size() > 0;
|
||||
}
|
||||
|
||||
// since #3228 we now have to manually manage the KV cache
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||
// for multiple images processing
|
||||
bool ingestImages(llama_client_slot &slot, int n_batch) {
|
||||
int image_idx = 0;
|
||||
while(image_idx < slot.images.size()) {
|
||||
slot_image img = slot.images[image_idx];
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", n_past},
|
||||
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
|
||||
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
|
||||
});
|
||||
// process prefix prompt
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
if (llama_decode(ctx, batch_view)) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
has_next_token = true;
|
||||
// process image with llm
|
||||
for (int i = 0; i < img.image_tokens; i += n_batch) {
|
||||
int n_eval = img.image_tokens - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch_img)) {
|
||||
LOG_TEE("%s : failed to eval image\n", __func__);
|
||||
return false;
|
||||
}
|
||||
slot.n_past += n_eval;
|
||||
}
|
||||
image_idx++;
|
||||
|
||||
// append prefix of next image
|
||||
batch.n_tokens = 0;
|
||||
const auto json_prompt = (image_idx >= slot.images.size()) ?
|
||||
slot.params.input_suffix : // no more images, then process suffix prompt
|
||||
(json)(slot.images[image_idx].prefix_prompt);
|
||||
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
||||
for (int i = 0; i < append_tokens.size(); ++i) {
|
||||
llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
|
||||
slot.n_past += 1;
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool updateSlots() {
|
||||
@ -814,159 +929,198 @@ struct llama_server_context
|
||||
|
||||
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
||||
|
||||
truncated = true;
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
});
|
||||
}
|
||||
slot.n_past -= n_discard;
|
||||
|
||||
bool tg = true;
|
||||
while (n_past < embd.size())
|
||||
{
|
||||
int n_eval = (int)embd.size() - n_past;
|
||||
tg = n_eval == 1;
|
||||
if (n_eval > params.n_batch)
|
||||
{
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
slot.truncated = true;
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
|
||||
{
|
||||
LOG_ERROR("failed to eval", {
|
||||
{"n_eval", n_eval},
|
||||
{"n_past", n_past},
|
||||
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
});
|
||||
has_next_token = false;
|
||||
return result;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
|
||||
if (params.n_predict == 0)
|
||||
{
|
||||
has_next_token = false;
|
||||
result.tok = llama_token_eos(ctx);
|
||||
return result;
|
||||
}
|
||||
|
||||
{
|
||||
// out of user input, sample next token
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(llama_n_vocab(model));
|
||||
|
||||
result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
const int32_t n_probs = params.sampling_params.n_probs;
|
||||
if (params.sampling_params.temp <= 0 && n_probs > 0)
|
||||
{
|
||||
// For llama_sample_token_greedy we need to sort candidates
|
||||
llama_sample_softmax(ctx, &candidates_p);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
|
||||
{
|
||||
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
|
||||
}
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(result.tok);
|
||||
if (tg) {
|
||||
num_tokens_predicted++;
|
||||
}
|
||||
}
|
||||
|
||||
// add it to the context
|
||||
embd.push_back(result.tok);
|
||||
// decrement remaining sampling budget
|
||||
--n_remain;
|
||||
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
|
||||
{
|
||||
// stopping_word = llama_token_to_piece(ctx, embd.back());
|
||||
has_next_token = false;
|
||||
stopped_eos = true;
|
||||
LOG_VERBOSE("eos token found", {});
|
||||
return result;
|
||||
}
|
||||
|
||||
has_next_token = params.n_predict == -1 || n_remain != 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
|
||||
const stop_type type)
|
||||
{
|
||||
size_t stop_pos = std::string::npos;
|
||||
for (const std::string &word : params.antiprompt)
|
||||
{
|
||||
size_t pos;
|
||||
if (type == STOP_FULL)
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & slot : slots) {
|
||||
// release the slot
|
||||
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
|
||||
{
|
||||
const size_t tmp = word.size() + last_token_size;
|
||||
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
||||
pos = text.find(word, from_pos);
|
||||
}
|
||||
else
|
||||
{
|
||||
pos = find_partial_stop_string(word, text);
|
||||
}
|
||||
if (pos != std::string::npos &&
|
||||
(stop_pos == std::string::npos || pos < stop_pos))
|
||||
{
|
||||
if (type == STOP_FULL)
|
||||
{
|
||||
stopping_word = word;
|
||||
stopped_word = true;
|
||||
has_next_token = false;
|
||||
slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
|
||||
if(slot.state == SLEEPING) {
|
||||
LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.cache_tokens.size());
|
||||
} else {
|
||||
LOG_TEE("slot %i released\n", slot.id);
|
||||
}
|
||||
slot.command = NONE;
|
||||
continue;
|
||||
}
|
||||
|
||||
kv_cache_free -= slot.num_prompt_tokens;
|
||||
|
||||
if (
|
||||
slot.state == IDLE ||
|
||||
slot.state == SLEEPING ||
|
||||
slot.command == RELEASE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
|
||||
llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
slot.n_past += 1;
|
||||
}
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = params.n_batch;
|
||||
|
||||
// assign workload to the slots
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// need process the prompt
|
||||
if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_genereration = 0;
|
||||
|
||||
if(slot.infill) {
|
||||
bool suff_rm_leading_spc = true;
|
||||
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
|
||||
params.input_suffix.erase(0, 1);
|
||||
suff_rm_leading_spc = false;
|
||||
}
|
||||
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
|
||||
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
|
||||
const int space_token = 29871;
|
||||
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
|
||||
suffix_tokens.erase(suffix_tokens.begin());
|
||||
}
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
|
||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
|
||||
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
||||
prefix_tokens.push_back(llama_token_middle(ctx));
|
||||
prompt_tokens = prefix_tokens;
|
||||
} else {
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||
}
|
||||
|
||||
slot.num_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
if(!slot.params.cache_prompt) {
|
||||
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
|
||||
slot.n_past = 0;
|
||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
||||
} else {
|
||||
if (slot.params.n_keep < 0)
|
||||
{
|
||||
slot.params.n_keep = (int)slot.num_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
|
||||
//if input prompt is too big, truncate like normal
|
||||
if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
|
||||
{
|
||||
// applied bug of #3661
|
||||
const int n_left = max_ctx_per_slot - slot.params.n_keep;
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
|
||||
// Use half the left-over space in the context for the prompt
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.end() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", max_ctx_per_slot},
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
});
|
||||
slot.truncated = true;
|
||||
prompt_tokens = new_tokens;
|
||||
slot.num_prompt_tokens = prompt_tokens.size();
|
||||
GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
|
||||
}
|
||||
const size_t ps = slot.num_prompt_tokens;
|
||||
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
|
||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
|
||||
|
||||
slot.cache_tokens = prompt_tokens;
|
||||
|
||||
if (slot.n_past == slot.num_prompt_tokens) {
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
printf("we have to evaluate at least 1 token to generate logits\n");
|
||||
slot.n_past--;
|
||||
}
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", slot.n_past},
|
||||
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
||||
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
|
||||
});
|
||||
|
||||
bool ingest_images = processImages(slot); // has images?
|
||||
|
||||
// process the prefix of first image
|
||||
std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
|
||||
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
|
||||
}
|
||||
|
||||
if(ingest_images && !ingestImages(slot, n_batch)) {
|
||||
LOG_TEE("failed processing images\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
if (batch.n_tokens > 0) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
}
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
}
|
||||
stop_pos = pos;
|
||||
}
|
||||
}
|
||||
return stop_pos;
|
||||
}
|
||||
|
||||
completion_token_output doCompletion()
|
||||
{
|
||||
auto token_with_probs = nextToken();
|
||||
|
||||
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
|
||||
generated_text += token_text;
|
||||
|
||||
if (params.sampling_params.n_probs > 0)
|
||||
{
|
||||
generated_token_probs.push_back(token_with_probs);
|
||||
if (batch.n_tokens == 0) {
|
||||
all_slots_are_idle = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (multibyte_pending > 0)
|
||||
{
|
||||
multibyte_pending -= token_text.size();
|
||||
}
|
||||
else if (token_text.size() == 1)
|
||||
{
|
||||
const char c = token_text[0];
|
||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
||||
if ((c & 0xE0) == 0xC0)
|
||||
{
|
||||
multibyte_pending = 1;
|
||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF0) == 0xE0)
|
||||
{
|
||||
multibyte_pending = 2;
|
||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF8) == 0xF0)
|
||||
{
|
||||
multibyte_pending = 3;
|
||||
}
|
||||
else
|
||||
{
|
||||
multibyte_pending = 0;
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
if (ret != 0) {
|
||||
if (n_batch == 1 || ret < 0) {
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
n_batch /= 2;
|
||||
i -= n_batch;
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
@ -1625,9 +1779,77 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
|
||||
}
|
||||
}
|
||||
|
||||
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
|
||||
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
|
||||
if(!llama.multimodal) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
|
||||
const auto &images_data = body.find("image_data");
|
||||
if (images_data != body.end() && images_data->is_array())
|
||||
{
|
||||
for (const auto &img : *images_data)
|
||||
{
|
||||
slot_image img_sl;
|
||||
std::string data_b64 = img["data"].get<std::string>();
|
||||
img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
|
||||
int width, height, channels;
|
||||
std::vector<uint8_t> image_buffer = base64_decode(data_b64);
|
||||
data_b64.clear();
|
||||
auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
|
||||
if(!data) {
|
||||
LOG_TEE("slot %i - failed to load image id= %i\n", slot->id, img_sl.id);
|
||||
return;
|
||||
}
|
||||
LOG_TEE("slot %i - image id= %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height);
|
||||
img_sl.img_data.nx = width;
|
||||
img_sl.img_data.ny = height;
|
||||
img_sl.img_data.size = width * height * 3;
|
||||
img_sl.img_data.data = new uint8_t[width * height * 3]();
|
||||
memcpy(img_sl.img_data.data, data, width * height * 3);
|
||||
stbi_image_free(data);
|
||||
img_sl.request_encode_image = true;
|
||||
slot->images.push_back(img_sl);
|
||||
}
|
||||
// process prompt
|
||||
// example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
|
||||
if(slot->images.size() > 0 && !slot->prompt.is_array()) {
|
||||
std::string prompt = slot->prompt.get<std::string>();
|
||||
size_t pos = 0, begin_prefix = 0;
|
||||
std::string pattern = "[img-";
|
||||
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
|
||||
size_t end_prefix = pos;
|
||||
pos += pattern.length();
|
||||
size_t end_pos = prompt.find("]", pos);
|
||||
if (end_pos != std::string::npos) {
|
||||
std::string image_id = prompt.substr(pos, end_pos - pos);
|
||||
try {
|
||||
int img_id = std::stoi(image_id);
|
||||
bool found = false;
|
||||
for(slot_image &img : slot->images) {
|
||||
if(img.id == img_id) {
|
||||
found = true;
|
||||
img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
|
||||
begin_prefix = end_pos + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!found) {
|
||||
LOG_TEE("ERROR: Image with id %i not found.\n", img_id);
|
||||
slot->images.clear();
|
||||
return;
|
||||
}
|
||||
} catch (const std::invalid_argument& e) {
|
||||
LOG_TEE("Invalid image number id in prompt\n");
|
||||
slot->images.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
slot->prompt = "";
|
||||
slot->params.input_suffix = prompt.substr(begin_prefix);
|
||||
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)
|
||||
@ -2144,10 +2366,6 @@ int main(int argc, char **argv)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (llama.grammar != nullptr) {
|
||||
llama_grammar_free(llama.grammar);
|
||||
}
|
||||
llama_backend_free();
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user