server : add n_indent parameter for line indentation requirement (#9929)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-18 07:32:19 +03:00 committed by GitHub
parent 6f55bccbb8
commit 8901755ba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 7 deletions

View File

@ -333,6 +333,8 @@ node index.js
`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
`n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`
`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

View File

@ -131,6 +131,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@ -173,6 +174,8 @@ struct server_slot {
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
size_t last_nl_pos = 0;
std::string generated_text;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
@ -215,6 +218,7 @@ struct server_slot {
SLT_DBG(*this, "%s", "\n");
n_prompt_tokens = 0;
last_nl_pos = 0;
generated_text = "";
has_new_line = false;
truncated = false;
@ -860,6 +864,7 @@ struct server_context {
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@ -878,7 +883,7 @@ struct server_context {
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
@ -1129,13 +1134,48 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
// if we have already seen a new line, we stop after a certain time limit
if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
(ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;
if (slot.has_new_line) {
// if we have already seen a new line, we stop after a certain time limit
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
// TODO: improve by not doing it more than once for each new line
if (slot.last_nl_pos > 0) {
size_t pos = slot.last_nl_pos;
int n_indent = 0;
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
n_indent++;
pos++;
}
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
slot.stopped_limit = true;
slot.has_next_token = false;
// cut the last line
slot.generated_text.erase(pos, std::string::npos);
SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
}
}
// find the next new line
{
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
if (pos != std::string::npos) {
slot.last_nl_pos = pos + 1;
}
}
}
}
// check if there is a new line in the generated text