diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index bc0d042ae..436170a03 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<float> embedding;

+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
         return json {
-            {"index",     index},
-            {"embedding", embedding},
+            {"index",            index},
+            {"embedding",        embedding},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -735,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;

+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
         return json {
-            {"index", index},
-            {"score", score},
+            {"index",            index},
+            {"score",            score},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -1995,6 +2001,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;

         const int n_embd = llama_n_embd(model);

@@ -2030,6 +2037,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;

         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index fc7c20064..fea1d6510 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
         vi = res.body['data'][i]['embedding']
         for x, y in zip(v0, vi):
             assert abs(x - y) < EPSILON
+
+
+@pytest.mark.parametrize(
+    "content,n_tokens",
+    [
+        ("I believe the meaning of life is", 7),
+        ("This is a test", 4),
+    ]
+)
+def test_embedding_usage_single(content, n_tokens):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"input": content})
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+def test_embedding_usage_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == 2 * 7
diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py
index 189bc4c96..7203d7943 100644
--- a/examples/server/tests/unit/test_rerank.py
+++ b/examples/server/tests/unit/test_rerank.py
@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.parametrize(
+    "query,doc1,doc2,n_tokens",
+    [
+        ("Machine learning is", "A machine", "Learning is", 19),
+        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
+    ]
+)
+def test_rerank_usage(query, doc1, doc2, n_tokens):
+    global server
+    server.start()
+
+    res = server.make_request("POST", "/rerank", data={
+        "query": query,
+        "documents": [
+            doc1,
+            doc2,
+        ]
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c6f08bf21..8fffe484a 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -560,6 +560,7 @@ static json oaicompat_completion_params_parse(

 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
@@ -567,14 +568,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"index",     i++},
             {"object",    "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"data", data}
     };
@@ -584,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso

 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index",           i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"results", data}
     };
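
Not part of the patch: a minimal client-side sketch of reading the `usage` fields this change introduces. It assumes a llama.cpp server built with this patch, serving an embedding or reranking model at http://localhost:8080, and the `requests` package; the base URL is an assumption, while the endpoints, request bodies, and field names are taken from the diff above.

import requests

BASE_URL = "http://localhost:8080"  # assumption: adjust to your llama-server host/port

# Embeddings: usage.prompt_tokens is the sum of per-input tokens_evaluated,
# and total_tokens mirrors it, since embeddings produce no completion tokens.
res = requests.post(f"{BASE_URL}/embeddings", json={
    "input": ["I believe the meaning of life is", "This is a test"],
})
res.raise_for_status()
usage = res.json()["usage"]
assert usage["prompt_tokens"] == usage["total_tokens"]
print("embeddings usage:", usage)

# Rerank: same usage shape, summed over all (query, document) pairs.
res = requests.post(f"{BASE_URL}/rerank", json={
    "query": "Machine learning is",
    "documents": ["A machine", "Learning is"],
})
res.raise_for_status()
print("rerank usage:", res.json()["usage"])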