diff --git a/examples/server/chat.mjs b/examples/server/chat.mjs
index 103116bf2..219ebb51a 100644
--- a/examples/server/chat.mjs
+++ b/examples/server/chat.mjs
@@ -86,7 +86,7 @@ async function chat_completion(question) {
             n_predict: 256,
             cache_prompt: no_cached_prompt === "false",
             slot_id: slot_id,
-            stop: ["### Human:"], // stop completion after generating this
+            stop: ["\n### Human:"], // stop completion after generating this
             grammar,
             stream: true,
         })
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 66632a4ed..b8bff2443 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -316,6 +316,7 @@ struct llama_client_slot
     struct slot_params params;
     struct llama_sampling_params sparams;
     llama_sampling_context ctx_sampling;
+    bool has_next_token = true;

     // grammar props
    grammar_parser::parse_state parsed_grammar;
@@ -710,9 +711,14 @@ struct llama_server_context
            if (pos != std::string::npos &&
                (stop_pos == std::string::npos || pos < stop_pos))
            {
+                if (type == STOP_FULL)
+                {
+                    slot.stopped_word = true;
+                    slot.stopping_word = word;
+                    slot.has_next_token = false;
+                }
                stop_pos = pos;
-                slot.stopped_word = true;
-                slot.stopping_word = word;
+
            }
        }
        return stop_pos;
@@ -727,6 +733,8 @@ struct llama_server_context

        // search stop word and delete it
        slot.generated_text += token_str;
+        slot.has_next_token = true;
+
        size_t pos = std::min(slot.sent_count, slot.generated_text.size());
        const std::string str_test = slot.generated_text.substr(pos);
        bool is_stop_full = false;
@@ -744,15 +752,13 @@ struct llama_server_context
        }

        // check if there is any token to predict
-        bool has_next_token = !is_stop_full && stop_pos > 0;
-        if(stop_pos == std::string::npos) {
+        if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
            // no send the stop word in the response
            result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
            slot.sent_count += result.text_to_send.size();
-            has_next_token = true;
+            // add the token to slot queue and cache
+            slot.addTokenString(result);
        }
-        // add the token to slot queue and cache
-        slot.addTokenString(result);
        if (slot.multibyte_pending > 0)
        {
            slot.multibyte_pending -= token_str.size();
@@ -781,29 +787,29 @@ struct llama_server_context
            }
        }

-        if (slot.multibyte_pending > 0 && !has_next_token)
+        if (slot.multibyte_pending > 0 && !slot.has_next_token)
        {
-            has_next_token = true;
+            slot.has_next_token = true;
        }

        // check the limits
        if (
-            slot.n_decoded > 2 && has_next_token && !slot.hasBudget(params))
+            slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params))
        {
            slot.stopped_limit = true;
-            has_next_token = false;
+            slot.has_next_token = false;
        }

        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
            slot.stopped_eos = true;
-            has_next_token = false;
+            slot.has_next_token = false;
            LOG_VERBOSE("eos token found", {});
        }

        LOG_VERBOSE("next token", {
                        {"token", result.tok},
                        {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
-                        {"has_next_token", has_next_token},
+                        {"has_next_token", slot.has_next_token},
                        {"n_remain", slot.n_remaining},
                        {"num_tokens_predicted", slot.num_tokens_predicted},
                        {"stopped_eos", slot.stopped_eos},
@@ -811,7 +817,7 @@
                        {"stopped_limit", slot.stopped_limit},
                        {"stopping_word", slot.stopping_word},
                    });
-        return has_next_token; // continue
+        return slot.has_next_token; // continue
    }

 #ifdef SERVER_MULTIMODAL_SUPPORT
@@ -2293,7 +2299,6 @@ int main(int argc, char **argv)
                const json body = json::parse(req.body);
                llama_client_slot* slot = llama.getSlot(-1);
                slot->reset();
-                //llama_reset_timings(llama.ctx);
                if (body.count("content") != 0)
                {
                    slot->prompt = body["content"];
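
For reference, a minimal sketch (not part of the patch) of how a client can stream from the server's `/completion` endpoint and rely on the stop string touched above. The URL, port, and prompt text are illustrative assumptions; the request fields and the SSE handling follow the pattern already used in `examples/server/chat.mjs`.

```js
// Sketch only: assumes a llama.cpp server running locally on the default port.
const response = await fetch("http://127.0.0.1:8080/completion", {
    method: "POST",
    body: JSON.stringify({
        prompt: "### Human: Hello!\n### Assistant:",
        n_predict: 256,
        stop: ["\n### Human:"], // generation halts here; the stop text itself is not streamed back
        stream: true,
    }),
})

const decoder = new TextDecoder()
for await (const chunk of response.body) {
    const text = decoder.decode(chunk)
    if (!text.startsWith("data: ")) {
        continue // sketch: assumes one SSE event per chunk, as chat.mjs does
    }
    const message = JSON.parse(text.substring(6))
    process.stdout.write(message.content)
    if (message.stop) {
        break // final event: stop word, EOS, or prediction limit reached
    }
}
```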