diff --git a/examples/server/README.md b/examples/server/README.md index 63a7bf43a..ccb40dba3 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -442,6 +442,8 @@ These words will not be included in the completion, so make sure to add them to `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` +`requested_fields`: A list of required response fields, for example : `"requested_fields": ["content", "generation_settings/n_predict"]` If there is no field, return an empty json for that field. + **Response format** - Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 062ebcd4a..83d1a5d77 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -249,6 +249,42 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int): # assert match_regex(re_content, res.body["content"]) +@pytest.mark.parametrize( + "prompt,n_predict,requested_fields", + [ + ("I believe the meaning of life is", 8, []), + ( + "I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"], + ), + ], +) +def test_completion_requested_fields( + prompt: str, n_predict: int, requested_fields: list[str] +): + global server + server.start() + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": prompt, + "requested_fields": requested_fields, + }, + ) + assert res.status_code == 200 + assert "content" in res.body + assert len(res.body["content"]) + if len(requested_fields) > 0: + assert res.body["generation_settings/n_predict"] == n_predict + assert res.body["prompt"] == " " + prompt + assert isinstance(res.body["content"], str) + assert len(res.body) == len(requested_fields) + else: + assert len(res.body) > 0 + assert "generation_settings" in res.body + + def test_n_probs(): global server server.start() diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 0ac8b2cce..9ad900067 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -89,19 +89,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) { } // get value by path(key1 / key2) -static json json_get_nested_values(const std::vector& paths, const json& js) { +static json json_get_nested_values(const std::vector & paths, const json & js) { json result = json::object(); - - for (const std::string& path : paths) { + + for (const std::string & path : paths) { json current = js; - std::istringstream stream(path); - std::string key; - std::vector keys; - while (std::getline(stream, key, '/')) { - keys.push_back(key); - } + const auto keys = string_split(path, /*delim*/ '/'); bool valid_path = true; - for (const std::string& k : keys) { + for (const std::string & k : keys) { if (valid_path && current.is_object() && current.contains(k)) { current = current[k]; } else {