new sampling API

FSSRepo 2023-10-18 16:50:09 -04:00
parent 84b8f2b060
commit 7196c4e08a


@@ -125,7 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int n_keep = 0; // RNG seed
+    int n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
@@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
+struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params sparams, std::string grammar, int n_ctx) {
+    struct llama_sampling_context * result = new llama_sampling_context();
+
+    result->params = sparams;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+            grammar_rules.data(),
+            grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(n_ctx);
+
+    return result;
+}
+
 struct slot_image {
     clip_image_u8 img_data;
     bool request_encode_image = false;
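
For context: the new helper gives each server slot its own llama_sampling_context, which bundles the sampler parameters, the parsed grammar, the last-token history (prev) and the candidate buffer (cur) that were previously tracked by hand, and a grammar parse failure now surfaces as a nullptr return. A minimal sketch of the intended per-slot lifecycle (illustrative only, not code from this commit; the wrapper name reset_slot_sampling is made up):

    #include <string>
    #include "sampling.h"   // common/sampling.h in llama.cpp: llama_sampling_context, llama_sampling_free (assumed include path)

    // Illustrative sketch: rebuild a slot's sampling state on reset or on a new request.
    static bool reset_slot_sampling(llama_sampling_context *& ctx_sampling,
                                    const llama_sampling_params & sparams,
                                    const std::string & grammar,
                                    int max_context_size) {
        if (ctx_sampling != nullptr) {
            llama_sampling_free(ctx_sampling);   // frees the old grammar state and token history
        }
        ctx_sampling = llama_sampling_init_srv(sparams, grammar, max_context_size);
        return ctx_sampling != nullptr;          // nullptr means the grammar failed to parse
    }

This is exactly the shape that reset() and loadGrammar() take further down in the diff.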
@@ -287,7 +315,6 @@ struct llama_client_slot
     int num_tokens_predicted = 0;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
-    std::vector<llama_token> last_n_tokens;
     std::vector<completion_token_output> generated_token_probs;
     int sent_tokens = 0;
     slot_state state = IDLE;
@@ -307,13 +334,12 @@ struct llama_client_slot
     double t_token_generation; // ms
 
     struct slot_params params;
-    struct llama_sampling_params sparams;
-    llama_sampling_context ctx_sampling;
-    bool has_next_token = true;
 
-    // grammar props
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
+    // sampling
+    struct llama_sampling_params sparams;
+    llama_sampling_context* ctx_sampling = nullptr;
+    bool has_next_token = true;
+    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -332,47 +358,26 @@ struct llama_client_slot
         infill = false;
         clean_tokens();
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling.params = sparams;
-            ctx_sampling.grammar = NULL;
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
 
         for(slot_image img : images) {
             free(img.image_embedding);
             delete[] img.img_data.data;
             img.prefix_prompt = "";
         }
         images.clear();
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }
 
     bool loadGrammar(llama_token eos)
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-            {
-                auto it = sparams.logit_bias.find(eos);
-                if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling.params = sparams;
-        ctx_sampling.grammar = grammar;
-        return true;
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        return ctx_sampling != nullptr;
     }
 
     bool hasBudget(gpt_params &global_params) {
@@ -448,7 +453,6 @@ struct llama_server_context
     llama_model *model = nullptr;
    llama_context *ctx = nullptr;
    llama_batch batch;
-    std::vector<llama_token_data> candidates;
     bool all_slots_are_idle = false;
     gpt_params params;
     int n_ctx;
@ -468,11 +472,6 @@ struct llama_server_context
llama_free_model(model); llama_free_model(model);
model = nullptr; model = nullptr;
} }
for(auto &slot : slots) {
if(slot.grammar) {
llama_grammar_free(slot.grammar);
}
}
} }
bool loadModel(const gpt_params &params_) bool loadModel(const gpt_params &params_)
@@ -510,7 +509,6 @@ struct llama_server_context
         }
         n_ctx = llama_n_ctx(ctx);
         n_vocab = llama_n_vocab(model);
-        candidates.reserve(n_vocab);
         return true;
     }
@@ -529,13 +527,12 @@ struct llama_server_context
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(max_ctx_per_slot);
-            std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+            slot.max_context_size = max_ctx_per_slot;
             slot.reset();
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
             slots.push_back(slot);
         }
-        batch = llama_batch_init(n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0, 1);
 
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
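
llama_batch_init takes a third argument in this llama.cpp revision; as far as I can tell it is the maximum number of sequence ids that may be attached to a single token, which the server sets to 1 because each token belongs to exactly one slot sequence. A hedged sketch of the call pair (the signature should be checked against llama.h; the value 512 is illustrative):

    // Assumed signature at this revision: llama_batch_init(n_tokens, embd, n_seq_max).
    // embd == 0 means the batch carries token ids rather than embeddings.
    llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/1);
    // ... fill with llama_batch_add(...) and submit with llama_decode(...) ...
    llama_batch_free(batch);   // releases the token/pos/n_seq_id/seq_id/logits arrays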
@@ -626,10 +623,7 @@ struct llama_server_context
         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            batch.token[i] = tokens_system[i];
-            batch.pos[i] = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0)
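
llama_batch_add is the small helper from llama.cpp's common library that this commit switches to throughout the server: it writes one token plus its position, sequence ids and logits flag into the batch and advances n_tokens itself. Roughly what it does, as a paraphrased sketch rather than code copied from the commit (see common/common.h for the real definition):

    #include <vector>
    #include "llama.h"   // llama_batch, llama_token, llama_pos, llama_seq_id

    // Sketch of llama_batch_add(batch, id, pos, seq_ids, logits):
    // append one token to the batch and advance n_tokens.
    static void batch_add_sketch(llama_batch & batch, llama_token id, llama_pos pos,
                                 const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
        for (size_t s = 0; s < seq_ids.size(); ++s) {
            batch.seq_id[batch.n_tokens][s] = seq_ids[s];   // per-token list of sequence ids
        }
        batch.logits  [batch.n_tokens] = logits;
        batch.n_tokens++;
    }

Because the helper bumps n_tokens on its own, callers should not pre-fill or re-increment that counter around it.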
@@ -726,8 +720,6 @@ struct llama_server_context
     bool processToken(completion_token_output & result, llama_client_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        slot.last_n_tokens.erase(slot.last_n_tokens.begin());
-        slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
@ -859,11 +851,12 @@ struct llama_server_context
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = { llama_batch batch_view = {
n_tokens, n_tokens,
batch.token + i, batch.token + i,
nullptr, nullptr,
batch.pos + i, batch.pos + i,
batch.seq_id + i, batch.n_seq_id + i,
batch.logits + i, batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused 0, 0, 0, // unused
}; };
if (llama_decode(ctx, batch_view)) { if (llama_decode(ctx, batch_view)) {
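
The extra entry in the initializer exists because llama_batch changed shape in this revision: it gained an n_seq_id array, and seq_id became a per-token array of sequence ids rather than a single id per token. A sketch of building such a view over a sub-range of an existing batch, using the same field order as above (assumes the llama.h of this revision; verify against the header):

    #include "llama.h"

    // Sketch: a view of n_tokens entries of `batch` starting at offset i.
    static llama_batch make_batch_view(const llama_batch & batch, int32_t i, int32_t n_tokens) {
        llama_batch view = {
            n_tokens,
            batch.token    + i,
            nullptr,              // embd: unused when token ids are present
            batch.pos      + i,
            batch.n_seq_id + i,   // how many sequence ids each token carries
            batch.seq_id   + i,   // per-token arrays of sequence ids
            batch.logits   + i,
            0, 0, 0,              // legacy all_pos_0 / all_pos_1 / all_seq_id, unused here
        };
        return view;
    }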
@ -878,8 +871,8 @@ struct llama_server_context
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
if (llama_decode(ctx, batch)) { if (llama_decode(ctx, batch_img)) {
LOG_TEE("%s : failed to eval image\n", __func__); LOG_TEE("%s : failed to eval image\n", __func__);
return false; return false;
} }
@ -894,10 +887,7 @@ struct llama_server_context
(json)(slot.images[image_idx].prefix_prompt); (json)(slot.images[image_idx].prefix_prompt);
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < append_tokens.size(); ++i) { for (int i = 0; i < append_tokens.size(); ++i) {
batch.token [batch.n_tokens] = append_tokens[i]; llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
batch.pos [batch.n_tokens] = slot.n_past;
batch.seq_id[batch.n_tokens] = slot.id;
batch.logits[batch.n_tokens] = false;
slot.n_past += 1; slot.n_past += 1;
batch.n_tokens += 1; batch.n_tokens += 1;
} }
@ -922,7 +912,6 @@ struct llama_server_context
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));
} }
// context shift takes effect only when there is a single slot
for(llama_client_slot &slot : slots) { for(llama_client_slot &slot : slots) {
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot) if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
{ {
@ -976,16 +965,12 @@ struct llama_server_context
continue; continue;
} }
batch.token [batch.n_tokens] = slot.sampled; slot.i_batch = batch.n_tokens;
batch.pos [batch.n_tokens] = num_tokens_system + slot.n_past;
batch.seq_id[batch.n_tokens] = slot.id; llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
batch.logits[batch.n_tokens] = true;
slot.n_decoded += 1; slot.n_decoded += 1;
slot.i_batch = batch.n_tokens;
slot.n_past += 1; slot.n_past += 1;
batch.n_tokens += 1;
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
int32_t n_batch = params.n_batch; int32_t n_batch = params.n_batch;
@@ -1026,7 +1011,7 @@ struct llama_server_context
             slot.num_prompt_tokens = prompt_tokens.size();
 
             if(!slot.params.cache_prompt) {
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
                 slot.n_past = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
@@ -1038,23 +1023,27 @@ struct llama_server_context
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
+                    // applied bug of #3661
                     const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    const int n_block_size = n_left / 2;
+                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
-                        {"n_ctx", n_ctx},
-                        {"n_keep", params.n_keep},
+                        {"n_ctx", max_ctx_per_slot},
+                        {"n_keep", slot.params.n_keep},
                         {"n_left", n_left},
                         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
                     slot.num_prompt_tokens = prompt_tokens.size();
+                    GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
                 }
 
                 const size_t ps = slot.num_prompt_tokens;
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
+                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
                 LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
@ -1081,11 +1070,7 @@ struct llama_server_context
// process the prefix of first image // process the prefix of first image
std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens; std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) { for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
batch.token [batch.n_tokens] = prefix_tokens[slot.n_past]; llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
} }
if(ingest_images && !ingestImages(slot, n_batch)) { if(ingest_images && !ingestImages(slot, n_batch)) {
@ -1113,11 +1098,12 @@ struct llama_server_context
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = { llama_batch batch_view = {
n_tokens, n_tokens,
batch.token + i, batch.token + i,
nullptr, nullptr,
batch.pos + i, batch.pos + i,
batch.seq_id + i, batch.n_seq_id + i,
batch.logits + i, batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused 0, 0, 0, // unused
}; };
@@ -1150,25 +1136,27 @@ struct llama_server_context
                 }
 
                 completion_token_output result;
-                const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
+                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+                llama_sampling_accept(slot.ctx_sampling, ctx, id);
 
                 if (slot.n_decoded == 1) {
                     slot.t_start_genereration = ggml_time_us();
                     slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
                 }
 
-                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
 
                 const int32_t n_probs = slot.sparams.n_probs;
                 if (slot.sparams.temp <= 0 && n_probs > 0)
                 {
                     // For llama_sample_token_greedy we need to sort candidates
-                    llama_sample_softmax(ctx, &candidates_p);
+                    llama_sample_softmax(ctx, &cur_p);
                 }
 
-                for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+                for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
                 {
-                    result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                    result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
                 }
 
                 if (!processToken(result, slot)) {
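
The sampling call is the core of the migration: the old signature took the shared candidates vector and the slot's last_n_tokens ring buffer explicitly, while the new API keeps both inside the sampling context (cur and prev) and splits drawing a token from recording it. A sketch of the per-token pattern (the wrapper function is illustrative only; the two library calls are the ones used above):

    #include "sampling.h"   // common/sampling.h: llama_sampling_sample / llama_sampling_accept (assumed include path)

    // Sketch: sample the next token from the logits at index i_batch of the last
    // llama_decode call, then record it so repetition penalties and the grammar
    // see it on the following step.
    static llama_token sample_next(llama_sampling_context * ctx_sampling,
                                   llama_context * ctx, int i_batch) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /*ctx_cfg=*/nullptr, i_batch);
        llama_sampling_accept(ctx_sampling, ctx, id);
        return id;
    }

After the call, ctx_sampling->cur still holds this step's filtered candidate list, which is what the server reads above to report per-token probabilities.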