diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ad2ae4ee4..3e073de1c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -91,6 +91,7 @@ struct completion_token_output
 
     std::vector<token_prob> probs;
     llama_token tok;
+    std::string text_to_send;
 };
 
 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
@@ -231,6 +232,7 @@ struct llama_client_slot
     bool stopped_limit = false;
     std::string stopping_word;
    int32_t multibyte_pending = 0;
+    size_t sent_count = 0;
 
     struct slot_params params;
     struct llama_sampling_params sparams;
@@ -453,7 +455,6 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
-            printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }
 
@@ -492,7 +493,6 @@ struct llama_server_context
         // compare the evaluated prompt with the new prompt
     }
 
-
     llama_client_slot* getSlot(int id) {
         for (llama_client_slot & slot : slots) {
@@ -703,17 +703,36 @@ struct llama_server_context
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
-        slot.addTokenString(result);
         slot.generated_text += token_str;
-        size_t stop_pos = findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL, slot);
-
-        bool has_next_token = !(slot.n_decoded > 2 &&
-            (result.tok == llama_token_eos(ctx) ||
-            (slot.n_decoded + slot.n_past >=
-            params.n_predict) ||
-            stop_pos != std::string::npos));
+        size_t pos = std::min(slot.sent_count, slot.generated_text.size());
+        const std::string str_test = slot.generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            slot.generated_text.erase(
+                slot.generated_text.begin() + pos + stop_pos,
+                slot.generated_text.end());
+            pos = std::min(slot.sent_count, slot.generated_text.size());
+            result.tok = -1;
+        } else {
+            is_stop_full = false;
+            stop_pos = findStoppingStrings(str_test, token_str.size(),
+                STOP_PARTIAL, slot);
+        }
+        bool has_next_token = !is_stop_full && stop_pos > 0;
+        if(stop_pos == std::string::npos) {
+            result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
+            slot.sent_count += result.text_to_send.size();
+            has_next_token = true;
+        }
+        slot.addTokenString(result);
+        if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
+            slot.n_past + slot.n_decoded >= params.n_predict)) {
+            has_next_token = false;
+        }
 
         if (slot.sparams.n_probs > 0)
         {
             slot.generated_token_probs.push_back(result);
@@ -804,6 +823,7 @@ struct llama_server_context
                 slot.state = IDLE;
                 slot.command = NONE;
                 slot.generated_text.clear();
+                return true;
             } else {
                 slot.state = SLEEPING;
                 slot.command = NONE;
@@ -1748,10 +1768,9 @@ int main(int argc, char **argv)
            while(slot->isProcessing()) {
                if(slot->hasNewToken()) { // new token notification
                    const completion_token_output token = slot->next();
-                   std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
                    std::vector<completion_token_output> probs_output = {};
                    if (slot->sparams.n_probs > 0) {
-                       const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token_str, false);
+                       const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                        size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
                        if (probs_pos < probs_stop_pos) {
@@ -1759,7 +1778,7 @@ int main(int argc, char **argv)
                        }
                        sent_token_probs_index = probs_stop_pos;
                    }
-                   const json data = format_partial_response(llama, slot, token_str, probs_output);
+                   const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
                    const std::string str =
                        "data: " +
                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -1796,7 +1815,7 @@ int main(int argc, char **argv)
                    sink.done();
                    return true;
                };
-               auto on_complete = [&] (bool) {
+               auto on_complete = [slot, &llama] (bool) {
                    llama.mutex.unlock();
                    slot->sent_tokens = 0;
                    slot->generated_token_probs.clear();
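
For reference, below is a minimal standalone sketch (not the server code itself) of the streaming scheme this patch implements: track how many bytes have already been sent (`sent_count`), scan only the unsent suffix of the generated text for stop strings, hold text back while it still ends in a possible partial stop match, and flush everything up to (but not including) a full stop string. The helper names `StreamState`, `find_full_stop`, and `find_partial_stop` are illustrative only; the server itself does this through `findStoppingStrings` with `STOP_FULL` / `STOP_PARTIAL`.

```cpp
// Standalone sketch of incremental stop-string handling during streaming.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct StreamState {
    std::string generated_text; // everything generated so far
    size_t      sent_count = 0; // how many bytes were already streamed to the client
};

// Position of the first complete stop string in `text`, or npos.
static size_t find_full_stop(const std::string & text, const std::vector<std::string> & stops) {
    size_t best = std::string::npos;
    for (const auto & s : stops) {
        if (s.empty()) continue;
        size_t p = text.find(s);
        if (p != std::string::npos && p < best) best = p;
    }
    return best;
}

// Position where `text` ends with a proper prefix of some stop string, or npos.
static size_t find_partial_stop(const std::string & text, const std::vector<std::string> & stops) {
    for (const auto & s : stops) {
        if (s.empty()) continue;
        for (size_t len = std::min(s.size() - 1, text.size()); len > 0; len--) {
            if (text.compare(text.size() - len, len, s, 0, len) == 0) return text.size() - len;
        }
    }
    return std::string::npos;
}

// Append a newly decoded piece and return the text that is safe to stream now.
static std::string on_new_piece(StreamState & st, const std::string & piece,
                                const std::vector<std::string> & stops, bool & stopped) {
    st.generated_text += piece;

    size_t pos = std::min(st.sent_count, st.generated_text.size());
    const std::string unsent = st.generated_text.substr(pos);

    size_t stop_pos = find_full_stop(unsent, stops);
    if (stop_pos != std::string::npos) {
        // Full stop string found: truncate it away and flush what precedes it.
        st.generated_text.erase(st.generated_text.begin() + pos + stop_pos, st.generated_text.end());
        stopped = true;
    } else if (find_partial_stop(unsent, stops) != std::string::npos) {
        // The tail might still grow into a stop string: hold everything back for now.
        return "";
    }

    std::string to_send = st.generated_text.substr(pos);
    st.sent_count += to_send.size();
    return to_send;
}

int main() {
    StreamState st;
    const std::vector<std::string> stops = {"</s>"};
    bool stopped = false;
    for (const std::string piece : {"Hello", ", wor", "ld!<", "/s>ignored"}) {
        const std::string out = on_new_piece(st, piece, stops, stopped);
        if (!out.empty()) printf("send: %s\n", out.c_str());
        if (stopped) break;
    }
    return 0;
}
```

Scanning only the unsent suffix keeps per-token work proportional to the newly generated text and guarantees that nothing beyond a stop string, or inside a possible one, is ever streamed to the client.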