fixed premature end due to stop word

FSSRepo 2023-10-16 12:36:05 -04:00
parent fd64f04fc2
commit 2d9f11db28
2 changed files with 21 additions and 16 deletions

View File

@@ -86,7 +86,7 @@ async function chat_completion(question) {
         n_predict: 256,
         cache_prompt: no_cached_prompt === "false",
         slot_id: slot_id,
-        stop: ["### Human:"], // stop completion after generating this
+        stop: ["\n### Human:"], // stop completion after generating this
         grammar,
         stream: true,
     })
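The stop string now includes the leading newline, so it matches the turn marker as the model emits it (at the start of a line), and the newline before "### Human:" is withheld from the streamed reply together with the marker.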

View File

@@ -316,6 +316,7 @@ struct llama_client_slot
     struct slot_params params;
     struct llama_sampling_params sparams;
     llama_sampling_context ctx_sampling;
+    bool has_next_token = true;

     // grammar props
     grammar_parser::parse_state parsed_grammar;
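has_next_token moves from a local in the token-processing path (see the hunks below) into per-slot state, so the stop-word search can clear it directly when a full stop word is found, and every other stopping condition updates the same persistent flag.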
@@ -710,9 +711,14 @@ struct llama_server_context
             if (pos != std::string::npos &&
                 (stop_pos == std::string::npos || pos < stop_pos))
             {
+                if (type == STOP_FULL)
+                {
+                    slot.stopped_word = true;
+                    slot.stopping_word = word;
+                    slot.has_next_token = false;
+                }
+
                 stop_pos = pos;
-                slot.stopped_word = true;
-                slot.stopping_word = word;
             }
         }
         return stop_pos;
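The type guard is the core of the fix: only a full stop-word match (STOP_FULL) marks the slot as stopped; a partial match at the tail of the generated text just holds characters back until more tokens arrive. A minimal sketch of that distinction, separate from the server's code (find_stop and STOP_PARTIAL here are illustrative names, not the actual API):

    #include <algorithm>
    #include <iostream>
    #include <string>

    enum stop_type { STOP_FULL, STOP_PARTIAL };

    // STOP_FULL: the whole stop word occurs in `text`, so generation must stop.
    // STOP_PARTIAL: `text` merely ends with a prefix of the stop word; those
    // characters are held back, but generation continues.
    static size_t find_stop(const std::string & text, const std::string & word, stop_type type) {
        if (type == STOP_FULL) {
            return text.find(word);
        }
        for (size_t len = std::min(word.size(), text.size()); len > 0; --len) {
            if (text.compare(text.size() - len, len, word, 0, len) == 0) {
                return text.size() - len; // start of the held-back tail
            }
        }
        return std::string::npos;
    }

    int main() {
        const std::string word = "\n### Human:";
        const std::string text = "Sure, here you go.\n###";
        // Partial match: "\n###" may still grow into the stop word -> hold it back.
        std::cout << find_stop(text, word, STOP_PARTIAL) << "\n"; // prints 18
        // No full match yet -> this alone must not end generation.
        std::cout << (find_stop(text, word, STOP_FULL) == std::string::npos) << "\n"; // prints 1
    }

Before this commit the stop-word search set stopped_word for partial matches too, and the local has_next_token computation could come out false on a partial match, ending the completion before the stop word was actually produced.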
@@ -727,6 +733,8 @@ struct llama_server_context
         // search stop word and delete it
         slot.generated_text += token_str;
+        slot.has_next_token = true;
+
         size_t pos = std::min(slot.sent_count, slot.generated_text.size());
         const std::string str_test = slot.generated_text.substr(pos);
         bool is_stop_full = false;
@@ -744,15 +752,13 @@ struct llama_server_context
         }

         // check if there is any token to predict
-        bool has_next_token = !is_stop_full && stop_pos > 0;
-        if(stop_pos == std::string::npos) {
+        if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
             // no send the stop word in the response
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
-            has_next_token = true;
+            // add the token to slot queue and cache
+            slot.addTokenString(result);
         }

-        // add the token to slot queue and cache
-        slot.addTokenString(result);
-
         if (slot.multibyte_pending > 0)
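Two changes here: the send-or-hold decision now reads the persistent slot.has_next_token instead of recomputing a local, and addTokenString() moves inside the branch, so text that may still complete a stop word is neither streamed to the client nor appended to the slot's queue and cache until the match is resolved (or generation ends, in which case the held text is flushed).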
@@ -781,29 +787,29 @@ struct llama_server_context
             }
         }

-        if (slot.multibyte_pending > 0 && !has_next_token)
+        if (slot.multibyte_pending > 0 && !slot.has_next_token)
         {
-            has_next_token = true;
+            slot.has_next_token = true;
         }

         // check the limits
         if (
-            slot.n_decoded > 2 && has_next_token && !slot.hasBudget(params))
+            slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params))
         {
             slot.stopped_limit = true;
-            has_next_token = false;
+            slot.has_next_token = false;
         }

         if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
             slot.stopped_eos = true;
-            has_next_token = false;
+            slot.has_next_token = false;
             LOG_VERBOSE("eos token found", {});
         }

         LOG_VERBOSE("next token", {
                         {"token", result.tok},
                         {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
-                        {"has_next_token", has_next_token},
+                        {"has_next_token", slot.has_next_token},
                         {"n_remain", slot.n_remaining},
                         {"num_tokens_predicted", slot.num_tokens_predicted},
                         {"stopped_eos", slot.stopped_eos},
@@ -811,7 +817,7 @@ struct llama_server_context
                         {"stopped_limit", slot.stopped_limit},
                         {"stopping_word", slot.stopping_word},
                     });
-        return has_next_token; // continue
+        return slot.has_next_token; // continue
     }

 #ifdef SERVER_MULTIMODAL_SUPPORT
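With the flag stored on the slot, the value returned here (and logged as has_next_token above) persists across tokens; a partial stop-word match by itself no longer forces it to false, which is the premature end the commit title refers to.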
@@ -2293,7 +2299,6 @@ int main(int argc, char **argv)
             const json body = json::parse(req.body);
             llama_client_slot* slot = llama.getSlot(-1);
             slot->reset();
-            //llama_reset_timings(llama.ctx);
             if (body.count("content") != 0)
             {
                 slot->prompt = body["content"];