sampling : hide prev behind API and apply #3661

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-20 18:26:20 +03:00
parent 7e2b5fb1dd
commit 56ba00b923
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
9 changed files with 119 additions and 105 deletions

View File

@ -66,6 +66,24 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
dst->prev = src->prev; dst->prev = src->prev;
} }
llama_token llama_sampling_last(llama_sampling_context * ctx) {
return ctx->prev.back();
}
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
const int size = ctx_sampling->prev.size();
n = std::min(n, size);
std::string result;
for (int i = size - n; i < size; i++) {
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
}
return result;
}
std::string llama_sampling_print(const llama_sampling_params & params) { std::string llama_sampling_print(const llama_sampling_params & params) {
char result[1024]; char result[1024];
@ -193,11 +211,12 @@ llama_token llama_sampling_sample(
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
llama_token id) { llama_token id,
bool apply_grammar) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin()); ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id); ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL) { if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
} }
} }

View File

@ -70,6 +70,12 @@ void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context // Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// Get the last sampled token
llama_token llama_sampling_last(llama_sampling_context * ctx);
// Get a string representation of the last sampled tokens
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
// Print sampling parameters into a string // Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params); std::string llama_sampling_print(const llama_sampling_params & params);
@ -99,4 +105,5 @@ llama_token llama_sampling_sample(
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
llama_token id); llama_token id,
bool apply_grammar);

View File

@ -12,25 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN) if (EMSCRIPTEN)
else() else()
add_subdirectory(main)
add_subdirectory(quantize)
add_subdirectory(quantize-stats)
add_subdirectory(perplexity)
add_subdirectory(embedding)
add_subdirectory(save-load-state)
add_subdirectory(benchmark)
add_subdirectory(baby-llama) add_subdirectory(baby-llama)
add_subdirectory(train-text-from-scratch)
add_subdirectory(finetune)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple)
add_subdirectory(batched) add_subdirectory(batched)
add_subdirectory(batched-bench) add_subdirectory(batched-bench)
add_subdirectory(speculative)
add_subdirectory(parallel)
add_subdirectory(llava)
add_subdirectory(llama-bench)
add_subdirectory(beam-search) add_subdirectory(beam-search)
add_subdirectory(benchmark)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
add_subdirectory(main)
add_subdirectory(parallel)
add_subdirectory(perplexity)
add_subdirectory(quantize)
add_subdirectory(quantize-stats)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL) if (LLAMA_METAL)
add_subdirectory(metal) add_subdirectory(metal)
endif() endif()

View File

@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id); llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]); // push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
++n_consumed; ++n_consumed;
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
break; break;
@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode // deal with eot token in infill mode
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){ if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) { if(is_interacting && !params.interactive_first) {
// print an eot token // print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
buffer += line; buffer += line;
} while (another_line); } while (another_line);
// check if we got an empty line, if so we use the old input // check if we got an empty line, if so we use the old input
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_prefix = buffer; params.input_prefix = buffer;
} }
buffer.clear(); buffer.clear();
@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
buffer += line; buffer += line;
} while (another_line); } while (another_line);
// check if we got an empty line // check if we got an empty line
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_suffix = buffer; params.input_suffix = buffer;
} }
buffer.clear(); buffer.clear();
@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
process_escapes(params.input_suffix); process_escapes(params.input_suffix);
} }
suff_rm_leading_spc = params.escape; suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1); params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false; suff_rm_leading_spc = false;
} }
@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
is_interacting = false; is_interacting = false;
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {

View File

@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id); llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@ -630,12 +630,9 @@ int main(int argc, char ** argv) {
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
// GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context // push the prompt in the sampling context in order to apply repetition penalties later
// Most likely will remove this in the future to avoid exposing "prev" // for the prompt, we don't apply grammar rules
// Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
// penalty will be applied only based on the tokens generated by the model.
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed; ++n_consumed;
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
@ -666,12 +663,10 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs; // if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt // check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) { if (!params.antiprompt.empty()) {
std::string last_output; const int n_prev = 32;
for (auto id : ctx_sampling->prev) { const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
last_output += llama_token_to_piece(ctx, id);
}
is_antiprompt = false; is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output. // Check if each of the reverse prompts appears at the end of the output.
@ -698,7 +693,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {

View File

@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
llama_sampling_accept(client.ctx_sampling, ctx, id); llama_sampling_accept(client.ctx_sampling, ctx, id, true);
if (client.n_decoded == 1) { if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients // start measuring generation time after the first token to make sure all concurrent clients

View File

@ -195,10 +195,12 @@ struct llama_server_context
json prompt; json prompt;
std::vector<llama_token> embd; std::vector<llama_token> embd;
gpt_params params;
llama_model *model = nullptr; llama_model *model = nullptr;
llama_context *ctx = nullptr; llama_context *ctx = nullptr;
gpt_params params;
llama_sampling_context *ctx_sampling = nullptr; llama_sampling_context *ctx_sampling = nullptr;
int n_ctx; int n_ctx;
bool truncated = false; bool truncated = false;
@ -246,7 +248,10 @@ struct llama_server_context
multibyte_pending = 0; multibyte_pending = 0;
n_remain = 0; n_remain = 0;
n_past = 0; n_past = 0;
params.sparams.n_prev = n_ctx;
}
void initSampling() {
if (ctx_sampling != nullptr) { if (ctx_sampling != nullptr) {
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
} }
@ -311,16 +316,32 @@ struct llama_server_context
return prompt_tokens; return prompt_tokens;
} }
bool loadGrammar() void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
{ const int n_left = n_ctx - params.n_keep;
ctx_sampling = llama_sampling_init(params.sparams); const int n_block_size = n_left / 2;
return true; const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
// Keep n_keep tokens at start of prompt (at most n_ctx - 4)
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
{"num_prompt_tokens", new_tokens.size()}
});
truncated = true;
prompt_tokens = new_tokens;
} }
void loadInfill() void loadInfill()
{ {
bool suff_rm_leading_spc = true; bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1); params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false; suff_rm_leading_spc = false;
} }
@ -336,6 +357,7 @@ struct llama_server_context
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx)); prefix_tokens.push_back(llama_token_middle(ctx));
auto prompt_tokens = prefix_tokens; auto prompt_tokens = prefix_tokens;
num_prompt_tokens = prompt_tokens.size(); num_prompt_tokens = prompt_tokens.size();
@ -347,31 +369,18 @@ struct llama_server_context
params.n_keep = std::min(params.n_ctx - 4, params.n_keep); params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal // if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx) if (num_prompt_tokens >= (size_t) n_ctx)
{ {
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); truncatePrompt(prompt_tokens);
// todo we probably want to cut from both sides num_prompt_tokens = prompt_tokens.size();
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", { GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
} }
else
// push the prompt into the sampling context (do not apply grammar)
for (auto & token : prompt_tokens)
{ {
const size_t ps = num_prompt_tokens; llama_sampling_accept(ctx_sampling, ctx, token, false);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@ -409,29 +418,18 @@ struct llama_server_context
params.n_keep = std::min(n_ctx - 4, params.n_keep); params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal // if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)n_ctx) if (num_prompt_tokens >= (size_t) n_ctx)
{ {
const int n_left = (n_ctx - params.n_keep) / 2; truncatePrompt(prompt_tokens);
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); num_prompt_tokens = prompt_tokens.size();
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", { GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
} }
else
// push the prompt into the sampling context (do not apply grammar)
for (auto & token : prompt_tokens)
{ {
const size_t ps = num_prompt_tokens; llama_sampling_accept(ctx_sampling, ctx, token, false);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@ -542,7 +540,7 @@ struct llama_server_context
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
} }
llama_sampling_accept(ctx_sampling, ctx, result.tok); llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
if (tg) { if (tg) {
num_tokens_predicted++; num_tokens_predicted++;
@ -1206,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla
} }
} }
llama.ctx_sampling = llama_sampling_init(llama.params.sparams);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
} }
@ -1376,15 +1372,9 @@ int main(int argc, char **argv)
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
parse_options_completion(json::parse(req.body), llama); parse_options_completion(json::parse(req.body), llama);
if (!llama.loadGrammar()) llama.initSampling();
{
res.status = 400;
return;
}
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
@ -1539,14 +1529,9 @@ int main(int argc, char **argv)
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
parse_options_infill(json::parse(req.body), llama); parse_options_infill(json::parse(req.body), llama);
if (!llama.loadGrammar()) llama.initSampling();
{
res.status = 400;
return;
}
llama.loadInfill(); llama.loadInfill();
llama.beginCompletion(); llama.beginCompletion();
const auto chunked_content_provider = [&](size_t, DataSink & sink) { const auto chunked_content_provider = [&](size_t, DataSink & sink) {
@ -1696,7 +1681,9 @@ int main(int argc, char **argv)
const json body = json::parse(req.body); const json body = json::parse(req.body);
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
if (body.count("content") != 0) if (body.count("content") != 0)
{ {
llama.prompt = body["content"]; llama.prompt = body["content"];
@ -1706,6 +1693,8 @@ int main(int argc, char **argv)
llama.prompt = ""; llama.prompt = "";
} }
llama.params.n_predict = 0; llama.params.n_predict = 0;
llama.initSampling();
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
llama.doCompletion(); llama.doCompletion();

View File

@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
// sample from the target model // sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id); llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
const int s = sa[is]; const int s = sa[is];
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id); llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id); drafts[s].tokens.push_back(id);