server : minor style

This commit is contained in:
Georgi Gerganov 2023-10-22 19:52:38 +03:00
parent a4d69d8b81
commit dd1af2ed35
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -403,6 +403,7 @@ struct llama_client_slot
sent_count = 0; sent_count = 0;
sent_token_probs_index = 0; sent_token_probs_index = 0;
infill = false; infill = false;
generated_token_probs.clear(); generated_token_probs.clear();
for (slot_image &img : images) for (slot_image &img : images)
@ -882,7 +883,8 @@ struct llama_server_context
// wait until system prompt load // wait until system prompt load
system_need_update = true; system_need_update = true;
while (system_need_update) { while (system_need_update)
{
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));
} }
// system prompt loaded, continue // system prompt loaded, continue
@ -997,26 +999,31 @@ struct llama_server_context
const std::string str_test = slot.generated_text.substr(pos); const std::string str_test = slot.generated_text.substr(pos);
bool is_stop_full = false; bool is_stop_full = false;
size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot); size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot);
if (stop_pos != std::string::npos) { if (stop_pos != std::string::npos)
{
is_stop_full = true; is_stop_full = true;
slot.generated_text.erase( slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos, slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end()); slot.generated_text.end());
pos = std::min(slot.sent_count, slot.generated_text.size()); pos = std::min(slot.sent_count, slot.generated_text.size());
} else { }
else
{
is_stop_full = false; is_stop_full = false;
stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot); stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot);
} }
// check if there is any token to predict // check if there is any token to predict
if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0))
{
// no send the stop word in the response // no send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos); result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.sent_count += result.text_to_send.size(); slot.sent_count += result.text_to_send.size();
// add the token to slot queue and cache // add the token to slot queue and cache
} }
slot.add_token_string(result); slot.add_token_string(result);
if (slot.params.stream) { if (slot.params.stream)
{
send_partial_response(slot, result); send_partial_response(slot, result);
} }
} }
@ -1051,6 +1058,7 @@ struct llama_server_context
{"stopped_limit", slot.stopped_limit}, {"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word}, {"stopping_word", slot.stopping_word},
}); });
return slot.has_next_token; // continue return slot.has_next_token; // continue
} }
@ -1089,7 +1097,8 @@ struct llama_server_context
return slot.images.size() > 0; return slot.images.size() > 0;
} }
void send_error(int id, std::string error) { void send_error(int id, std::string error)
{
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = id; res.id = id;
@ -1098,11 +1107,13 @@ struct llama_server_context
queue_results.push_back(res); queue_results.push_back(res);
} }
json get_model_props() { json get_model_props()
{
return get_formated_generation(slots[0]); return get_formated_generation(slots[0]);
} }
json get_formated_generation(llama_client_slot &slot) { json get_formated_generation(llama_client_slot &slot)
{
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx)); const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@ -1134,12 +1145,14 @@ struct llama_server_context
}; };
} }
void send_partial_response(llama_client_slot & slot, completion_token_output tkn) { void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
{
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
res.error = false; res.error = false;
res.stop = false; res.stop = false;
res.result_json = json res.result_json = json
{ {
{"content", tkn.text_to_send}, {"content", tkn.text_to_send},
@ -1147,6 +1160,7 @@ struct llama_server_context
{"slot_id", slot.id}, {"slot_id", slot.id},
{"multimodal", multimodal} {"multimodal", multimodal}
}; };
if (slot.sparams.n_probs > 0) if (slot.sparams.n_probs > 0)
{ {
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
@ -1160,15 +1174,18 @@ struct llama_server_context
slot.sent_token_probs_index = probs_stop_pos; slot.sent_token_probs_index = probs_stop_pos;
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
} }
queue_results.push_back(res); queue_results.push_back(res);
} }
void send_final_response(llama_client_slot & slot) { void send_final_response(llama_client_slot &slot)
{
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
res.error = false; res.error = false;
res.stop = true; res.stop = true;
res.result_json = json res.result_json = json
{ {
{"content", !slot.params.stream ? slot.generated_text : ""}, {"content", !slot.params.stream ? slot.generated_text : ""},
@ -1191,20 +1208,25 @@ struct llama_server_context
if (slot.sparams.n_probs > 0) if (slot.sparams.n_probs > 0)
{ {
std::vector<completion_token_output> probs = {}; std::vector<completion_token_output> probs = {};
if(!slot.params.stream && slot.stopped_word) { if (!slot.params.stream && slot.stopped_word)
{
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size());
} else { }
else
{
probs = std::vector<completion_token_output>( probs = std::vector<completion_token_output>(
slot.generated_token_probs.begin(), slot.generated_token_probs.begin(),
slot.generated_token_probs.begin() + slot.sent_token_probs_index); slot.generated_token_probs.begin() + slot.sent_token_probs_index);
} }
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
} }
queue_results.push_back(res); queue_results.push_back(res);
} }
void send_embedding(llama_client_slot & slot) { void send_embedding(llama_client_slot &slot)
{
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
@ -1234,7 +1256,8 @@ struct llama_server_context
queue_results.push_back(res); queue_results.push_back(res);
} }
int request_completion(json data, bool infill) { int request_completion(json data, bool infill)
{
std::lock_guard<std::mutex> lock(mutex_tasks); std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task; task_server task;
task.id = id_gen++; task.id = id_gen++;
@ -1245,17 +1268,22 @@ struct llama_server_context
return task.id; return task.id;
} }
task_result next_result(int task_id) { task_result next_result(int task_id)
while (true) { {
while (true)
{
std::this_thread::sleep_for(std::chrono::microseconds(5)); std::this_thread::sleep_for(std::chrono::microseconds(5));
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
if (queue_results.empty()) { if (queue_results.empty())
{
continue; continue;
} }
for (int i = 0; i < (int) queue_results.size(); i++) { for (int i = 0; i < (int) queue_results.size(); i++)
if (queue_results[i].id == task_id) { {
if (queue_results[i].id == task_id)
{
task_result res = queue_results[i]; task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i); queue_results.erase(queue_results.begin() + i);
return res; return res;
@ -1335,7 +1363,8 @@ struct llama_server_context
return true; return true;
} }
void request_cancel(int task_id) { void request_cancel(int task_id)
{
std::lock_guard<std::mutex> lock(mutex_tasks); std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task; task_server task;
task.id = id_gen++; task.id = id_gen++;
@ -1344,9 +1373,11 @@ struct llama_server_context
queue_tasks.push_back(task); queue_tasks.push_back(task);
} }
void process_tasks() { void process_tasks()
{
std::lock_guard<std::mutex> lock(mutex_tasks); std::lock_guard<std::mutex> lock(mutex_tasks);
while (!queue_tasks.empty()) { while (!queue_tasks.empty())
{
task_server task = queue_tasks.front(); task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin()); queue_tasks.erase(queue_tasks.begin());
switch (task.type) switch (task.type)
@ -1379,8 +1410,10 @@ struct llama_server_context
} }
} break; } break;
case CANCEL_TASK: { // release slot linked with the task id case CANCEL_TASK: { // release slot linked with the task id
for (auto & slot : slots) { for (auto & slot : slots)
if (slot.task_id == task.target_id) { {
if (slot.task_id == task.target_id)
{
slot.release(); slot.release();
break; break;
} }
@ -2006,7 +2039,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
else if (arg == "--embedding") else if (arg == "--embedding")
{ {
params.embedding = true; params.embedding = true;
} else if (arg == "-cb" || arg == "--cont-batching") }
else if (arg == "-cb" || arg == "--cont-batching")
{ {
params.cont_batching = true; params.cont_batching = true;
} }
@ -2047,7 +2081,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
); );
llama.process_system_prompt_data(json::parse(systm_content)); llama.process_system_prompt_data(json::parse(systm_content));
} }
else if(arg == "--mmproj") { else if(arg == "--mmproj")
{
if (++i >= argc) if (++i >= argc)
{ {
invalid_param = true; invalid_param = true;
@ -2163,6 +2198,7 @@ int main(int argc, char **argv)
LOG_INFO("build info", {{"build", BUILD_NUMBER}, LOG_INFO("build info", {{"build", BUILD_NUMBER},
{"commit", BUILD_COMMIT}}); {"commit", BUILD_COMMIT}});
LOG_INFO("system info", { LOG_INFO("system info", {
{"n_threads", params.n_threads}, {"n_threads", params.n_threads},
{"n_threads_batch", params.n_threads_batch}, {"n_threads_batch", params.n_threads_batch},
@ -2239,8 +2275,10 @@ int main(int argc, char **argv)
return; return;
} }
} else { } else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
while(true) { {
while (true)
{
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
if (!result.error) { if (!result.error) {
const std::string str = const std::string str =
@ -2264,10 +2302,13 @@ int main(int argc, char **argv)
sink.done(); sink.done();
return true; return true;
}; };
auto on_complete = [task_id, &llama] (bool) {
auto on_complete = [task_id, &llama] (bool)
{
// cancel // cancel
llama.request_cancel(task_id); llama.request_cancel(task_id);
}; };
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} }
}); });
@ -2279,7 +2320,8 @@ int main(int argc, char **argv)
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
if(!result.error && result.stop) { if (!result.error && result.stop)
{
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
} }
else else
@ -2290,7 +2332,8 @@ int main(int argc, char **argv)
} }
} else { } else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while(true) { while (true)
{
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
if (!result.error) { if (!result.error) {
const std::string str = const std::string str =
@ -2304,20 +2347,28 @@ int main(int argc, char **argv)
{ {
return false; return false;
} }
if(result.stop) { if (result.stop)
break; {
}
} else {
break; break;
} }
} }
else
{
break;
}
}
sink.done(); sink.done();
return true; return true;
}; };
auto on_complete = [task_id, &llama] (bool) {
auto on_complete = [task_id, &llama] (bool)
{
// cancel // cancel
llama.request_cancel(task_id); llama.request_cancel(task_id);
}; };
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} }
}); });