From 2e04ccf4e66a56eade51c2b62d7fe9026021fbb9 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Wed, 18 Dec 2024 01:21:44 +0300
Subject: [PATCH 1/4] llama_server_response_fields

---
 examples/server/server.cpp |  6 +++++-
 examples/server/utils.hpp  | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 436170a03..bc179cfb5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -91,6 +91,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> requested_fields;
     bool timings_per_token = false;
     bool ignore_eos = false;
 
@@ -205,6 +206,7 @@ struct server_task {
         params.n_discard        = json_value(data, "n_discard",        defaults.n_discard);
       //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+        params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result {
     stop_type stop = STOP_TYPE_NONE;
 
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string> requested_fields;
 
     slot_params generation_params;
 
@@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!probs_output.empty()) {
            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
-        return res;
+        return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -1960,6 +1963,7 @@ struct server_context {
         res->content = slot.generated_text;
         res->timings = slot.get_timings();
         res->prompt  = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->requested_fields = slot.params.requested_fields;
 
         res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8fffe484a..0ac8b2cce 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get value by path(key1 / key2)
+static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
+    json result = json::object();
+
+    for (const std::string& path : paths) {
+        json current = js;
+        std::istringstream stream(path);
+        std::string key;
+        std::vector<std::string> keys;
+        while (std::getline(stream, key, '/')) {
+            keys.push_back(key);
+        }
+        bool valid_path = true;
+        for (const std::string& k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"

From bc09b1acdf18ca199489938abc93bfc934552019 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 18:57:55 +0300
Subject: [PATCH 2/4] llama_server_response_fields_fix_issues

---
 examples/server/README.md                     |  2 ++
 examples/server/tests/unit/test_completion.py | 36 +++++++++++++++++++
 examples/server/utils.hpp                     | 15 +++-----
 3 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index 63a7bf43a..ccb40dba3 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -442,6 +442,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
+`requested_fields`: A list of response fields to return, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. Nested fields are selected with a `/`-separated path. If a requested field does not exist, it is simply omitted from the response. Default: all fields are returned
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 062ebcd4a..83d1a5d77 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -249,6 +249,42 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "prompt,n_predict,requested_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        (
+            "I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"],
+        ),
+    ],
+)
+def test_completion_requested_fields(
+    prompt: str, n_predict: int, requested_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "requested_fields": requested_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(requested_fields) > 0:
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == " " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(requested_fields)
+    else:
+        assert len(res.body) > 0
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 0ac8b2cce..9ad900067 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -89,19 +89,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
 }
 
 // get value by path(key1 / key2)
-static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
     json result = json::object();
-
-    for (const std::string& path : paths) {
+
+    for (const std::string & path : paths) {
         json current = js;
-        std::istringstream stream(path);
-        std::string key;
-        std::vector<std::string> keys;
-        while (std::getline(stream, key, '/')) {
-            keys.push_back(key);
-        }
+        const auto keys = string_split<std::string>(path, /*delim*/ '/');
         bool valid_path = true;
-        for (const std::string& k : keys) {
+        for (const std::string & k : keys) {
             if (valid_path && current.is_object() && current.contains(k)) {
                 current = current[k];
             } else {

From 0958ee96ac464f80d22d59bcd0b3593a0a2149be Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 19:16:28 +0300
Subject: [PATCH 3/4] params fixes

---
 examples/server/tests/unit/test_completion.py | 4 ++--
 examples/server/utils.hpp                     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 1a6c77974..ee65901f1 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -283,13 +283,13 @@ def test_completion_requested_fields(
     assert res.status_code == 200
     assert "content" in res.body
     assert len(res.body["content"])
-    if len(requested_fields) > 0:
+    if len(requested_fields):
         assert res.body["generation_settings/n_predict"] == n_predict
         assert res.body["prompt"] == " " + prompt
         assert isinstance(res.body["content"], str)
         assert len(res.body) == len(requested_fields)
     else:
-        assert len(res.body) > 0
+        assert len(res.body)
         assert "generation_settings" in res.body
 
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index e5164a889..d0e8d5266 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -94,7 +94,7 @@ static json json_get_nested_values(const std::vector<std::string> & paths, const
 
     for (const std::string & path : paths) {
         json current = js;
-        const auto keys = string_split<std::string>(path, /*delim*/ '/');
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
         bool valid_path = true;
         for (const std::string & k : keys) {
             if (valid_path && current.is_object() && current.contains(k)) {

From 3d3c6bae46417cdd572c6b3f2a3e132dc004ca31 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 19:18:54 +0300
Subject: [PATCH 4/4] fix

---
 examples/server/tests/unit/test_completion.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index ee65901f1..f7a427c33 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -261,9 +261,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
     "prompt,n_predict,requested_fields",
     [
         ("I believe the meaning of life is", 8, []),
-        (
-            "I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"],
-        ),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
     ],
 )
 def test_completion_requested_fields(
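
Reviewer note (not part of the patches above): a minimal usage sketch of the `requested_fields` parameter introduced by this series, assuming a llama.cpp server is already running locally (e.g. `llama-server -m model.gguf --port 8080`; the model path and port are placeholders) and that the third-party `requests` package is available. Nested values are addressed with the same `key1/key2` path syntax that `json_get_nested_values` resolves, and the returned keys are the full path strings, as the test above checks.

    import requests

    # Ask the /completion endpoint to return only two fields;
    # "generation_settings/n_predict" selects a nested value by path.
    resp = requests.post(
        "http://localhost:8080/completion",
        json={
            "prompt": "I believe the meaning of life is",
            "n_predict": 8,
            "requested_fields": ["content", "generation_settings/n_predict"],
        },
    )
    resp.raise_for_status()
    print(resp.json())  # e.g. {"content": "...", "generation_settings/n_predict": 8}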