From dd1af2ed35dfcf1e6842a57c4e69478e93622d89 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 22 Oct 2023 19:52:38 +0300 Subject: [PATCH] server : minor style --- examples/server/server.cpp | 145 +++++++++++++++++++++++++------------ 1 file changed, 98 insertions(+), 47 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 14edd3dd1..686a8c7c3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -391,18 +391,19 @@ struct llama_client_slot double t_token_generation; // ms void reset() { - num_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_past = 0; - sent_count = 0; + num_prompt_tokens = 0; + generated_text = ""; + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + multibyte_pending = 0; + n_past = 0; + sent_count = 0; sent_token_probs_index = 0; - infill = false; + infill = false; + generated_token_probs.clear(); for (slot_image &img : images) @@ -882,7 +883,8 @@ struct llama_server_context // wait until system prompt load system_need_update = true; - while (system_need_update) { + while (system_need_update) + { std::this_thread::sleep_for(std::chrono::milliseconds(5)); } // system prompt loaded, continue @@ -997,26 +999,31 @@ struct llama_server_context const std::string str_test = slot.generated_text.substr(pos); bool is_stop_full = false; size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot); - if (stop_pos != std::string::npos) { + if (stop_pos != std::string::npos) + { is_stop_full = true; slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.sent_count, slot.generated_text.size()); - } else { + } + else + { is_stop_full = false; stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot); } // check if there is any token to predict - if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) + { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.sent_count += result.text_to_send.size(); // add the token to slot queue and cache } slot.add_token_string(result); - if (slot.params.stream) { + if (slot.params.stream) + { send_partial_response(slot, result); } } @@ -1051,6 +1058,7 @@ struct llama_server_context {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, }); + return slot.has_next_token; // continue } @@ -1089,7 +1097,8 @@ struct llama_server_context return slot.images.size() > 0; } - void send_error(int id, std::string error) { + void send_error(int id, std::string error) + { std::lock_guard lock(mutex_results); task_result res; res.id = id; @@ -1098,11 +1107,13 @@ struct llama_server_context queue_results.push_back(res); } - json get_model_props() { + json get_model_props() + { return get_formated_generation(slots[0]); } - json get_formated_generation(llama_client_slot &slot) { + json get_formated_generation(llama_client_slot &slot) + { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); @@ -1134,12 +1145,14 @@ struct llama_server_context }; } - void send_partial_response(llama_client_slot & slot, completion_token_output tkn) { + void send_partial_response(llama_client_slot &slot, completion_token_output tkn) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; res.error = false; res.stop = false; + res.result_json = json { {"content", tkn.text_to_send}, @@ -1147,6 +1160,7 @@ struct llama_server_context {"slot_id", slot.id}, {"multimodal", multimodal} }; + if (slot.sparams.n_probs > 0) { std::vector probs_output = {}; @@ -1160,15 +1174,18 @@ struct llama_server_context slot.sent_token_probs_index = probs_stop_pos; res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); } + queue_results.push_back(res); } - void send_final_response(llama_client_slot & slot) { + void send_final_response(llama_client_slot &slot) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; res.error = false; res.stop = true; + res.result_json = json { {"content", !slot.params.stream ? slot.generated_text : ""}, @@ -1191,20 +1208,25 @@ struct llama_server_context if (slot.sparams.n_probs > 0) { std::vector probs = {}; - if(!slot.params.stream && slot.stopped_word) { + if (!slot.params.stream && slot.stopped_word) + { const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); probs = std::vector(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); - } else { + } + else + { probs = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.begin() + slot.sent_token_probs_index); } res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); } + queue_results.push_back(res); } - void send_embedding(llama_client_slot & slot) { + void send_embedding(llama_client_slot &slot) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; @@ -1234,7 +1256,8 @@ struct llama_server_context queue_results.push_back(res); } - int request_completion(json data, bool infill) { + int request_completion(json data, bool infill) + { std::lock_guard lock(mutex_tasks); task_server task; task.id = id_gen++; @@ -1245,17 +1268,22 @@ struct llama_server_context return task.id; } - task_result next_result(int task_id) { - while (true) { + task_result next_result(int task_id) + { + while (true) + { std::this_thread::sleep_for(std::chrono::microseconds(5)); std::lock_guard lock(mutex_results); - if (queue_results.empty()) { + if (queue_results.empty()) + { continue; } - for (int i = 0; i < (int) queue_results.size(); i++) { - if (queue_results[i].id == task_id) { + for (int i = 0; i < (int) queue_results.size(); i++) + { + if (queue_results[i].id == task_id) + { task_result res = queue_results[i]; queue_results.erase(queue_results.begin() + i); return res; @@ -1335,7 +1363,8 @@ struct llama_server_context return true; } - void request_cancel(int task_id) { + void request_cancel(int task_id) + { std::lock_guard lock(mutex_tasks); task_server task; task.id = id_gen++; @@ -1344,9 +1373,11 @@ struct llama_server_context queue_tasks.push_back(task); } - void process_tasks() { + void process_tasks() + { std::lock_guard lock(mutex_tasks); - while (!queue_tasks.empty()) { + while (!queue_tasks.empty()) + { task_server task = queue_tasks.front(); queue_tasks.erase(queue_tasks.begin()); switch (task.type) @@ -1379,8 +1410,10 @@ struct llama_server_context } } break; case CANCEL_TASK: { // release slot linked with the task id - for (auto & slot : slots) { - if (slot.task_id == task.target_id) { + for (auto & slot : slots) + { + if (slot.task_id == task.target_id) + { slot.release(); break; } @@ -2006,7 +2039,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, else if (arg == "--embedding") { params.embedding = true; - } else if (arg == "-cb" || arg == "--cont-batching") + } + else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; } @@ -2047,7 +2081,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, ); llama.process_system_prompt_data(json::parse(systm_content)); } - else if(arg == "--mmproj") { + else if(arg == "--mmproj") + { if (++i >= argc) { invalid_param = true; @@ -2163,6 +2198,7 @@ int main(int argc, char **argv) LOG_INFO("build info", {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); + LOG_INFO("system info", { {"n_threads", params.n_threads}, {"n_threads_batch", params.n_threads_batch}, @@ -2239,10 +2275,12 @@ int main(int argc, char **argv) return; } } else { - const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { - while(true) { + const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) + { + while (true) + { task_result result = llama.next_result(task_id); - if(!result.error) { + if (!result.error) { const std::string str = "data: " + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + @@ -2264,10 +2302,13 @@ int main(int argc, char **argv) sink.done(); return true; }; - auto on_complete = [task_id, &llama] (bool) { + + auto on_complete = [task_id, &llama] (bool) + { // cancel llama.request_cancel(task_id); }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); @@ -2279,7 +2320,8 @@ int main(int argc, char **argv) if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); - if(!result.error && result.stop) { + if (!result.error && result.stop) + { res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } else @@ -2290,9 +2332,10 @@ int main(int argc, char **argv) } } else { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { - while(true) { + while (true) + { task_result result = llama.next_result(task_id); - if(!result.error) { + if (!result.error) { const std::string str = "data: " + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + @@ -2304,20 +2347,28 @@ int main(int argc, char **argv) { return false; } - if(result.stop) { + if (result.stop) + { break; } - } else { + } + else + { break; } } + sink.done(); + return true; }; - auto on_complete = [task_id, &llama] (bool) { + + auto on_complete = [task_id, &llama] (bool) + { // cancel llama.request_cancel(task_id); }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } });