server : minor style

This commit is contained in:
Georgi Gerganov 2023-10-22 19:52:38 +03:00
parent a4d69d8b81
commit dd1af2ed35
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -391,18 +391,19 @@ struct llama_client_slot
double t_token_generation; // ms
void reset() {
num_prompt_tokens = 0;
generated_text = "";
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_past = 0;
sent_count = 0;
num_prompt_tokens = 0;
generated_text = "";
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_past = 0;
sent_count = 0;
sent_token_probs_index = 0;
infill = false;
infill = false;
generated_token_probs.clear();
for (slot_image &img : images)
@ -882,7 +883,8 @@ struct llama_server_context
// wait until system prompt load
system_need_update = true;
while (system_need_update) {
while (system_need_update)
{
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// system prompt loaded, continue
@ -997,26 +999,31 @@ struct llama_server_context
const std::string str_test = slot.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot);
if (stop_pos != std::string::npos) {
if (stop_pos != std::string::npos)
{
is_stop_full = true;
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.sent_count, slot.generated_text.size());
} else {
}
else
{
is_stop_full = false;
stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot);
}
// check if there is any token to predict
if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0))
{
// no send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.sent_count += result.text_to_send.size();
// add the token to slot queue and cache
}
slot.add_token_string(result);
if (slot.params.stream) {
if (slot.params.stream)
{
send_partial_response(slot, result);
}
}
@ -1051,6 +1058,7 @@ struct llama_server_context
{"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word},
});
return slot.has_next_token; // continue
}
@ -1089,7 +1097,8 @@ struct llama_server_context
return slot.images.size() > 0;
}
void send_error(int id, std::string error) {
void send_error(int id, std::string error)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = id;
@ -1098,11 +1107,13 @@ struct llama_server_context
queue_results.push_back(res);
}
json get_model_props() {
json get_model_props()
{
return get_formated_generation(slots[0]);
}
json get_formated_generation(llama_client_slot &slot) {
json get_formated_generation(llama_client_slot &slot)
{
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@ -1134,12 +1145,14 @@ struct llama_server_context
};
}
void send_partial_response(llama_client_slot & slot, completion_token_output tkn) {
void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.error = false;
res.stop = false;
res.result_json = json
{
{"content", tkn.text_to_send},
@ -1147,6 +1160,7 @@ struct llama_server_context
{"slot_id", slot.id},
{"multimodal", multimodal}
};
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
@ -1160,15 +1174,18 @@ struct llama_server_context
slot.sent_token_probs_index = probs_stop_pos;
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
}
queue_results.push_back(res);
}
void send_final_response(llama_client_slot & slot) {
void send_final_response(llama_client_slot &slot)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.error = false;
res.stop = true;
res.result_json = json
{
{"content", !slot.params.stream ? slot.generated_text : ""},
@ -1191,20 +1208,25 @@ struct llama_server_context
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs = {};
if(!slot.params.stream && slot.stopped_word) {
if (!slot.params.stream && slot.stopped_word)
{
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size());
} else {
}
else
{
probs = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.begin() + slot.sent_token_probs_index);
}
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
}
queue_results.push_back(res);
}
void send_embedding(llama_client_slot & slot) {
void send_embedding(llama_client_slot &slot)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
@ -1234,7 +1256,8 @@ struct llama_server_context
queue_results.push_back(res);
}
int request_completion(json data, bool infill) {
int request_completion(json data, bool infill)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
@ -1245,17 +1268,22 @@ struct llama_server_context
return task.id;
}
task_result next_result(int task_id) {
while (true) {
task_result next_result(int task_id)
{
while (true)
{
std::this_thread::sleep_for(std::chrono::microseconds(5));
std::lock_guard<std::mutex> lock(mutex_results);
if (queue_results.empty()) {
if (queue_results.empty())
{
continue;
}
for (int i = 0; i < (int) queue_results.size(); i++) {
if (queue_results[i].id == task_id) {
for (int i = 0; i < (int) queue_results.size(); i++)
{
if (queue_results[i].id == task_id)
{
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
@ -1335,7 +1363,8 @@ struct llama_server_context
return true;
}
void request_cancel(int task_id) {
void request_cancel(int task_id)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
@ -1344,9 +1373,11 @@ struct llama_server_context
queue_tasks.push_back(task);
}
void process_tasks() {
void process_tasks()
{
std::lock_guard<std::mutex> lock(mutex_tasks);
while (!queue_tasks.empty()) {
while (!queue_tasks.empty())
{
task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin());
switch (task.type)
@ -1379,8 +1410,10 @@ struct llama_server_context
}
} break;
case CANCEL_TASK: { // release slot linked with the task id
for (auto & slot : slots) {
if (slot.task_id == task.target_id) {
for (auto & slot : slots)
{
if (slot.task_id == task.target_id)
{
slot.release();
break;
}
@ -2006,7 +2039,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
else if (arg == "--embedding")
{
params.embedding = true;
} else if (arg == "-cb" || arg == "--cont-batching")
}
else if (arg == "-cb" || arg == "--cont-batching")
{
params.cont_batching = true;
}
@ -2047,7 +2081,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
);
llama.process_system_prompt_data(json::parse(systm_content));
}
else if(arg == "--mmproj") {
else if(arg == "--mmproj")
{
if (++i >= argc)
{
invalid_param = true;
@ -2163,6 +2198,7 @@ int main(int argc, char **argv)
LOG_INFO("build info", {{"build", BUILD_NUMBER},
{"commit", BUILD_COMMIT}});
LOG_INFO("system info", {
{"n_threads", params.n_threads},
{"n_threads_batch", params.n_threads_batch},
@ -2239,10 +2275,12 @@ int main(int argc, char **argv)
return;
}
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while(true) {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
{
while (true)
{
task_result result = llama.next_result(task_id);
if(!result.error) {
if (!result.error) {
const std::string str =
"data: " +
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) +
@ -2264,10 +2302,13 @@ int main(int argc, char **argv)
sink.done();
return true;
};
auto on_complete = [task_id, &llama] (bool) {
auto on_complete = [task_id, &llama] (bool)
{
// cancel
llama.request_cancel(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});
@ -2279,7 +2320,8 @@ int main(int argc, char **argv)
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
if(!result.error && result.stop) {
if (!result.error && result.stop)
{
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
}
else
@ -2290,9 +2332,10 @@ int main(int argc, char **argv)
}
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while(true) {
while (true)
{
task_result result = llama.next_result(task_id);
if(!result.error) {
if (!result.error) {
const std::string str =
"data: " +
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) +
@ -2304,20 +2347,28 @@ int main(int argc, char **argv)
{
return false;
}
if(result.stop) {
if (result.stop)
{
break;
}
} else {
}
else
{
break;
}
}
sink.done();
return true;
};
auto on_complete = [task_id, &llama] (bool) {
auto on_complete = [task_id, &llama] (bool)
{
// cancel
llama.request_cancel(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});