server : (refactoring) do not rely on JSON internally (#10643)

* server : (refactoring) reduce usage of json internally

* move all response types to struct

* wip [no ci]

* many fixes

* add virtual function

* fix index

* minor style fix

* add std::move

* refactor handle_completions_generic

* add virtual functions

* remove server.hpp

* clarify server_sent_event RFC specs

* apply review comments

* fix model_alias and completion_probabilities

* small clean up

* remove virtual for to_json_oai_compat()

* naming oai_compat --> oaicompat

* fix unwanted recursive call

* update docs
Author: Xuan Son Nguyen, 2024-12-06 11:14:32 +01:00 (committed by GitHub)
parent 7736837d62
commit 6c5bc0625f
8 changed files with 985 additions and 697 deletions

common/common.h

@@ -215,7 +215,7 @@ struct common_params {
     struct common_params_speculative speculative;

     std::string model = ""; // model path // NOLINT
-    std::string model_alias = "unknown"; // model alias // NOLINT
+    std::string model_alias = ""; // model alias // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
     std::string hf_token = ""; // HF token // NOLINT
     std::string hf_repo = ""; // HF repo // NOLINT

examples/server/README.md

@@ -473,9 +473,11 @@ Notice that each `probs` is an array of length `n_probs`.
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`
 - `prompt`: The provided `prompt`
-- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
-- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
-- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
+- `stop_type`: Indicating whether the completion has stopped. Possible values are:
+  - `none`: Generating (not stopped)
+  - `eos`: Stopped because it encountered the EOS token
+  - `limit`: Stopped because `n_predict` tokens were generated before stop words or EOS was encountered
+  - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
 - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
 - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
 - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
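
For reference, a minimal Python sketch of a client reading the new `stop_type` field (it assumes a `llama-server` instance listening on `localhost:8080`; the endpoint and field names come from the documentation change above):

```python
import requests

# Assumes llama-server is already running locally (host/port are illustrative).
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
})
res.raise_for_status()
body = res.json()

# stop_type replaces the old stopped_eos/stopped_limit/stopped_word booleans.
assert body["stop_type"] in ("none", "eos", "limit", "word")
if body["stop_type"] == "word":
    print("stopped on:", body["stopping_word"])
print(body["content"])
```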

File diff suppressed because it is too large.

examples/server/tests/README.md

@@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```

+Hint: You can compile and run the tests in a single command, which is useful for local development:
+
+```shell
+cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
+```
+
 To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

examples/server/tests/tests.sh

@@ -1,5 +1,9 @@
 #!/bin/bash

+# make sure we are in the right directory
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+cd $SCRIPT_DIR
+
 set -eu

 if [ $# -lt 1 ]

examples/server/tests/unit/test_chat_completion.py

@@ -12,13 +12,13 @@ def create_server():


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
@@ -30,29 +30,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
     assert match_regex(re_content, choice["message"]["content"])
-    if truncated:
-        assert choice["finish_reason"] == "length"
-    else:
-        assert choice["finish_reason"] == "stop"
+    assert choice["finish_reason"] == finish_reason


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
+    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
-        "model": model,
         "max_tokens": max_tokens,
         "messages": [
             {"role": "system", "content": system_prompt},
@@ -63,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r
     content = ""
     for data in res:
         choice = data["choices"][0]
+        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
             assert "content" not in choice["delta"]
             assert match_regex(re_content, content)
-            # FIXME: not sure why this is incorrect in stream mode
-            # if truncated:
-            #     assert choice["finish_reason"] == "length"
-            # else:
-            #     assert choice["finish_reason"] == "stop"
+            assert choice["finish_reason"] == finish_reason
         else:
             assert choice["finish_reason"] is None
             content += choice["delta"]["content"]
@@ -93,7 +88,7 @@ def test_chat_completion_with_openai_library():
         temperature=0.8,
     )
     print(res)
-    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
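
As context for the `finish_reason` and default-model assertions above, a standalone client using the official `openai` Python package against a local `llama-server` might look like the sketch below (host, port, model name, and `max_tokens` value are illustrative, not taken from the test suite):

```python
import openai

# Assumes llama-server is running locally; the API key is not checked by the server.
client = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

res = client.chat.completions.create(
    model="llama",  # illustrative; the reported model may fall back to the server's alias or DEFAULT_OAICOMPAT_MODEL
    messages=[
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ],
    max_tokens=8,
)

# When max_tokens is reached before EOS or a stop word, finish_reason is "length".
print(res.model, res.choices[0].finish_reason)
print(res.choices[0].message.content)
```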

examples/server/tests/unit/test_completion.py

@@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
         content += data["content"]


+def test_completion_stream_vs_non_stream():
+    global server
+    server.start()
+    res_stream = server.make_stream_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+        "stream": True,
+    })
+    res_non_stream = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+    })
+    content_stream = ""
+    for data in res_stream:
+        content_stream += data["content"]
+    assert content_stream == res_non_stream.body["content"]
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
@@ -221,3 +239,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         assert len(res.body["content"]) > 10
         # FIXME: the result is not deterministic when using other slot than slot 0
         # assert match_regex(re_content, res.body["content"])
+
+
+def test_n_probs():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "probs" in tok
+        assert len(tok["probs"]) == 10
+        for prob in tok["probs"]:
+            assert "prob" in prob
+            assert "tok_str" in prob
+            assert 0.0 <= prob["prob"] <= 1.0
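
Outside the test harness, the same `completion_probabilities` payload can be inspected directly; a small hedged sketch (assuming a local `llama-server` and using only the field names asserted in `test_n_probs` above):

```python
import requests

# Assumes llama-server is running locally; field names follow the test assertions above.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_probs": 10,
    "temperature": 0.0,
    "n_predict": 5,
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # each entry carries the top-n_probs candidates as {"tok_str", "prob"} pairs
    best = max(tok["probs"], key=lambda p: p["prob"])
    print(f'{best["tok_str"]!r}: {best["prob"]:.3f}')
```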

examples/server/utils.hpp

@@ -20,6 +20,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <memory>

 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@@ -40,17 +41,6 @@ using json = nlohmann::ordered_json;
 #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)

-// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
-enum error_type {
-    ERROR_TYPE_INVALID_REQUEST,
-    ERROR_TYPE_AUTHENTICATION,
-    ERROR_TYPE_SERVER,
-    ERROR_TYPE_NOT_FOUND,
-    ERROR_TYPE_PERMISSION,
-    ERROR_TYPE_UNAVAILABLE, // custom error
-    ERROR_TYPE_NOT_SUPPORTED, // custom error
-};
-
 template <typename T>
 static T json_value(const json & body, const std::string & key, const T & default_value) {
     // Fallback null to default value
@@ -485,48 +475,11 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }

-struct completion_token_output {
-    llama_token tok;
-    std::string text_to_send;
-
-    struct token_prob {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-};
-
-// convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob",    p.prob},
-            });
-        }
-
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs",   probs_for_token},
-        });
-    }
-
-    return out;
-}
-
 static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
     const std::string str =
         std::string(event) + ": " +
         data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; // note: these newlines are important (not sure why though, if you know, add a comment to explain)
+        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).

     LOG_DBG("data stream, to_send: %s", str.c_str());
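
To illustrate why the trailing blank line matters: each server-sent event is delimited by an empty line, so a client splits the stream on blank lines and strips the `data: ` prefix. A minimal sketch, assuming a local `llama-server` and the `requests` package (this is not the project's own test helper):

```python
import json
import requests

# Stream tokens from /completion; each SSE message arrives as "data: {...}\n\n".
with requests.post(
    "http://localhost:8080/completion",
    json={"prompt": "I believe the meaning of life is", "n_predict": 8, "stream": True},
    stream=True,
) as res:
    buffer = ""
    for raw in res.iter_lines(decode_unicode=True):
        if raw:  # accumulate the lines of the current event
            buffer += raw + "\n"
            continue
        # blank line: the event is complete, parse its data payload(s)
        for line in buffer.splitlines():
            if line.startswith("data: "):
                chunk = json.loads(line[len("data: "):])
                print(chunk.get("content", ""), end="", flush=True)
        buffer = ""
```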
@@ -604,164 +557,6 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res = json {
-        {"choices", choices},
-        {"created", t},
-        {"model",
-            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-        {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
-        }},
-        {"id", completion_id}
-    };
-
-    // extra fields for debugging purposes
-    if (verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    if (result.contains("timings")) {
-        res.push_back({"timings", json_value(result, "timings", json::object())});
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word",  false);
-    bool stopped_eos    = json_value(result, "stopped_eos",   false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content",       std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                          {"index", 0},
-                                                          {"delta", json{
-                                                              {"content", content}}}
-                                                          }})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                    json{
-                        {"content", content},
-                    }},
-            }});
-        }
-    }
-
-    json ret = json {
-        {"choices", choices},
-        {"created", t},
-        {"id",      completion_id},
-        {"model",   modelname},
-        {"object",  "chat.completion.chunk"}
-    };
-
-    if (result.contains("timings")) {
-        ret.push_back({"timings", json_value(result, "timings", json::object())});
-    }
-
-    if (!finish_reason.empty()) {
-        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-        ret.push_back({"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
-        }});
-    }
-
-    return std::vector<json>({ret});
-}
-
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
     int i = 0;
@@ -853,43 +648,3 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
-
-static json format_error_response(const std::string & message, const enum error_type type) {
-    std::string type_str;
-    int code = 500;
-    switch (type) {
-        case ERROR_TYPE_INVALID_REQUEST:
-            type_str = "invalid_request_error";
-            code = 400;
-            break;
-        case ERROR_TYPE_AUTHENTICATION:
-            type_str = "authentication_error";
-            code = 401;
-            break;
-        case ERROR_TYPE_NOT_FOUND:
-            type_str = "not_found_error";
-            code = 404;
-            break;
-        case ERROR_TYPE_SERVER:
-            type_str = "server_error";
-            code = 500;
-            break;
-        case ERROR_TYPE_PERMISSION:
-            type_str = "permission_error";
-            code = 403;
-            break;
-        case ERROR_TYPE_NOT_SUPPORTED:
-            type_str = "not_supported_error";
-            code = 501;
-            break;
-        case ERROR_TYPE_UNAVAILABLE:
-            type_str = "unavailable_error";
-            code = 503;
-            break;
-    }
-    return json {
-        {"code", code},
-        {"message", message},
-        {"type", type_str},
-    };
-}
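
For quick reference while this helper moves out of utils.hpp, the same error mapping can be expressed as a small Python sketch. It mirrors the removed function only; the names below are illustrative and not part of the codebase:

```python
# error_type -> (HTTP status, "type" string), as implemented by the removed format_error_response
ERROR_TYPES = {
    "ERROR_TYPE_INVALID_REQUEST": (400, "invalid_request_error"),
    "ERROR_TYPE_AUTHENTICATION":  (401, "authentication_error"),
    "ERROR_TYPE_NOT_FOUND":       (404, "not_found_error"),
    "ERROR_TYPE_SERVER":          (500, "server_error"),
    "ERROR_TYPE_PERMISSION":      (403, "permission_error"),
    "ERROR_TYPE_NOT_SUPPORTED":   (501, "not_supported_error"),  # custom error
    "ERROR_TYPE_UNAVAILABLE":     (503, "unavailable_error"),    # custom error
}

def format_error_response(message: str, error_type: str) -> dict:
    """Python mirror of the removed C++ helper: returns {"code", "message", "type"}."""
    code, type_str = ERROR_TYPES.get(error_type, (500, "server_error"))
    return {"code": code, "message": message, "type": type_str}

print(format_error_response("model not found", "ERROR_TYPE_NOT_FOUND"))
```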