mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
server : output embeddings for all tokens when pooling = none (#10861)
* server : add "tokens" output ggml-ci * server : output embeddings for all tokens when pooling = none ggml-ci * server : update readme [no ci] * server : fix spacing [no ci] Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * server : be explicit about the pooling type in the tests ggml-ci * server : update /embeddings and /v1/embeddings endpoints ggml-ci * server : do not normalize embeddings when there is no pooling ggml-ci * server : update readme ggml-ci * server : fixes * tests : update server tests ggml-ci * server : update readme [no ci] * server : remove rebase artifact --------- Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
0e70ba686e
commit
152610eda9
@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
|
|||||||
break;
|
break;
|
||||||
case 0: // max absolute
|
case 0: // max absolute
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
|
if (sum < std::abs(inp[i])) {
|
||||||
|
sum = std::abs(inp[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sum /= 32760.0; // make an int16 range
|
sum /= 32760.0; // make an int16 range
|
||||||
break;
|
break;
|
||||||
|
@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
|
|||||||
// Embedding utils
|
// Embedding utils
|
||||||
//
|
//
|
||||||
|
|
||||||
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
|
// TODO: replace embd_norm with an enum
|
||||||
|
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
|
||||||
|
|
||||||
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> emb_norm(emb_unorm.size());
|
std::vector<float> emb_norm(emb_unorm.size());
|
||||||
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
|
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
|
||||||
result.push_back(emb_norm);
|
result.push_back(emb_norm);
|
||||||
|
|
||||||
#ifdef GRIT_DEBUG
|
#ifdef GRIT_DEBUG
|
||||||
|
@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||||
common_embd_normalize(embd, out, n_embd);
|
common_embd_normalize(embd, out, n_embd, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \
|
|||||||
|
|
||||||
### POST `/v1/embeddings`: OpenAI-compatible embeddings API
|
### POST `/v1/embeddings`: OpenAI-compatible embeddings API
|
||||||
|
|
||||||
|
This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
|
||||||
|
|
||||||
*Options:*
|
*Options:*
|
||||||
|
|
||||||
See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
|
See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
|
||||||
@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### POST `/embeddings`: non-OpenAI-compatible embeddings API
|
||||||
|
|
||||||
|
This endpoint supports all pooling types, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using the Euclidean norm.
|
||||||
|
|
||||||
|
Note that the response format of this endpoint is different from `/v1/embeddings`.
|
||||||
|
|
||||||
|
*Options:*
|
||||||
|
|
||||||
|
Same as the `/v1/embeddings` endpoint.
|
||||||
|
|
||||||
|
*Examples:*
|
||||||
|
|
||||||
|
Same as the `/v1/embeddings` endpoint.
|
||||||
|
|
||||||
|
**Response format**
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"embedding": [
|
||||||
|
[ ... embeddings for token 0 ... ],
|
||||||
|
[ ... embeddings for token 1 ... ],
|
||||||
|
[ ... ]
|
||||||
|
[ ... embeddings for token N-1 ... ],
|
||||||
|
]
|
||||||
|
},
|
||||||
|
...
|
||||||
|
{
|
||||||
|
"index": P,
|
||||||
|
"embedding": [
|
||||||
|
[ ... embeddings for token 0 ... ],
|
||||||
|
[ ... embeddings for token 1 ... ],
|
||||||
|
[ ... ]
|
||||||
|
[ ... embeddings for token N-1 ... ],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
### GET `/slots`: Returns the current slots processing state
|
### GET `/slots`: Returns the current slots processing state
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
|
@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||||||
|
|
||||||
struct server_task_result_embd : server_task_result {
|
struct server_task_result_embd : server_task_result {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
std::vector<float> embedding;
|
std::vector<std::vector<float>> embedding;
|
||||||
|
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
|
||||||
|
// OAI-compat fields
|
||||||
|
bool oaicompat = false;
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
|
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
||||||
|
}
|
||||||
|
|
||||||
|
json to_json_non_oaicompat() {
|
||||||
return json {
|
return json {
|
||||||
{"index", index},
|
{"index", index},
|
||||||
{"embedding", embedding},
|
{"embedding", embedding},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
json to_json_oaicompat() {
|
||||||
|
return json {
|
||||||
|
{"index", index},
|
||||||
|
{"embedding", embedding[0]},
|
||||||
{"tokens_evaluated", n_tokens},
|
{"tokens_evaluated", n_tokens},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -2020,6 +2034,7 @@ struct server_context {
|
|||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
res->n_tokens = slot.n_prompt_tokens;
|
res->n_tokens = slot.n_prompt_tokens;
|
||||||
|
res->oaicompat = slot.params.oaicompat;
|
||||||
|
|
||||||
const int n_embd = llama_n_embd(model);
|
const int n_embd = llama_n_embd(model);
|
||||||
|
|
||||||
@ -2038,12 +2053,18 @@ struct server_context {
|
|||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||||
|
|
||||||
res->embedding = std::vector<float>(n_embd, 0.0f);
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_embd_normalize(embd, embd_res.data(), n_embd);
|
// normalize only when there is pooling
|
||||||
res->embedding = embd_res;
|
// TODO: configurable
|
||||||
|
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||||
|
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
|
||||||
|
res->embedding.push_back(embd_res);
|
||||||
|
} else {
|
||||||
|
res->embedding.push_back({ embd, embd + n_embd });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_DBG(slot, "%s", "sending embeddings\n");
|
SLT_DBG(slot, "%s", "sending embeddings\n");
|
||||||
@ -2657,7 +2678,10 @@ struct server_context {
|
|||||||
|
|
||||||
// add prompt tokens for processing in the current batch
|
// add prompt tokens for processing in the current batch
|
||||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||||
|
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||||
|
|
||||||
|
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
if (slot.params.cache_prompt) {
|
||||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||||
@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) {
|
|||||||
res_ok(res, data);
|
res_ok(res, data);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
bool oaicompat = false;
|
|
||||||
|
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||||
|
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// for the shape of input/content, see tokenize_input_prompts()
|
// for the shape of input/content, see tokenize_input_prompts()
|
||||||
json prompt;
|
json prompt;
|
||||||
if (body.contains("input")) {
|
if (body.count("input") != 0) {
|
||||||
oaicompat = true;
|
|
||||||
prompt = body.at("input");
|
prompt = body.at("input");
|
||||||
} else if (body.contains("content")) {
|
} else if (body.contains("content")) {
|
||||||
oaicompat = false;
|
oaicompat = false;
|
||||||
@ -3698,9 +3725,14 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
||||||
|
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||||
|
|
||||||
|
// OAI-compat
|
||||||
|
task.params.oaicompat = oaicompat;
|
||||||
|
|
||||||
tasks.push_back(task);
|
tasks.push_back(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3728,12 +3760,18 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write JSON response
|
// write JSON response
|
||||||
json root = oaicompat
|
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
|
||||||
? format_embeddings_response_oaicompat(body, responses)
|
|
||||||
: responses.size() == 1 ? responses[0] : json(responses);
|
|
||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
handle_embeddings_impl(req, res, false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
handle_embeddings_impl(req, res, true);
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||||
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
|
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
|
||||||
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) {
|
|||||||
svr->Post("/infill", handle_infill);
|
svr->Post("/infill", handle_infill);
|
||||||
svr->Post("/embedding", handle_embeddings); // legacy
|
svr->Post("/embedding", handle_embeddings); // legacy
|
||||||
svr->Post("/embeddings", handle_embeddings);
|
svr->Post("/embeddings", handle_embeddings);
|
||||||
svr->Post("/v1/embeddings", handle_embeddings);
|
svr->Post("/v1/embeddings", handle_embeddings_oai);
|
||||||
svr->Post("/rerank", handle_rerank);
|
svr->Post("/rerank", handle_rerank);
|
||||||
svr->Post("/reranking", handle_rerank);
|
svr->Post("/reranking", handle_rerank);
|
||||||
svr->Post("/v1/rerank", handle_rerank);
|
svr->Post("/v1/rerank", handle_rerank);
|
||||||
|
@ -14,8 +14,9 @@ def create_server():
|
|||||||
|
|
||||||
def test_embedding_single():
|
def test_embedding_single():
|
||||||
global server
|
global server
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
"input": "I believe the meaning of life is",
|
"input": "I believe the meaning of life is",
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@ -29,8 +30,9 @@ def test_embedding_single():
|
|||||||
|
|
||||||
def test_embedding_multiple():
|
def test_embedding_multiple():
|
||||||
global server
|
global server
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
"input": [
|
"input": [
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
"Write a joke about AI from a very long prompt which will not be truncated",
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
||||||
@ -46,7 +48,7 @@ def test_embedding_multiple():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"content,is_multi_prompt",
|
"input,is_multi_prompt",
|
||||||
[
|
[
|
||||||
# single prompt
|
# single prompt
|
||||||
("string", False),
|
("string", False),
|
||||||
@ -59,25 +61,55 @@ def test_embedding_multiple():
|
|||||||
([[12, 34, 56], [12, "string", 34, 56]], True),
|
([[12, 34, 56], [12, "string", 34, 56]], True),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_embedding_mixed_input(content, is_multi_prompt: bool):
|
def test_embedding_mixed_input(input, is_multi_prompt: bool):
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={"content": content})
|
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
data = res.body['data']
|
||||||
if is_multi_prompt:
|
if is_multi_prompt:
|
||||||
assert len(res.body) == len(content)
|
assert len(data) == len(input)
|
||||||
for d in res.body:
|
for d in data:
|
||||||
assert 'embedding' in d
|
assert 'embedding' in d
|
||||||
assert len(d['embedding']) > 1
|
assert len(d['embedding']) > 1
|
||||||
else:
|
else:
|
||||||
assert 'embedding' in res.body
|
assert 'embedding' in data[0]
|
||||||
assert len(res.body['embedding']) > 1
|
assert len(data[0]['embedding']) > 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_pooling_none():
|
||||||
|
global server
|
||||||
|
server.pooling = 'none'
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/embeddings", data={
|
||||||
|
"input": "hello hello hello",
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert 'embedding' in res.body[0]
|
||||||
|
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
|
||||||
|
|
||||||
|
# make sure embedding vector is not normalized
|
||||||
|
for x in res.body[0]['embedding']:
|
||||||
|
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_pooling_none_oai():
|
||||||
|
global server
|
||||||
|
server.pooling = 'none'
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
|
"input": "hello hello hello",
|
||||||
|
})
|
||||||
|
|
||||||
|
# /v1/embeddings does not support pooling type 'none'
|
||||||
|
assert res.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
def test_embedding_openai_library_single():
|
def test_embedding_openai_library_single():
|
||||||
global server
|
global server
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
|
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
|
||||||
assert len(res.data) == 1
|
assert len(res.data) == 1
|
||||||
assert len(res.data[0].embedding) > 1
|
assert len(res.data[0].embedding) > 1
|
||||||
@ -85,8 +117,9 @@ def test_embedding_openai_library_single():
|
|||||||
|
|
||||||
def test_embedding_openai_library_multiple():
|
def test_embedding_openai_library_multiple():
|
||||||
global server
|
global server
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
res = client.embeddings.create(model="text-embedding-3-small", input=[
|
res = client.embeddings.create(model="text-embedding-3-small", input=[
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
"Write a joke about AI from a very long prompt which will not be truncated",
|
"Write a joke about AI from a very long prompt which will not be truncated",
|
||||||
@ -100,8 +133,9 @@ def test_embedding_openai_library_multiple():
|
|||||||
|
|
||||||
def test_embedding_error_prompt_too_long():
|
def test_embedding_error_prompt_too_long():
|
||||||
global server
|
global server
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
"input": "This is a test " * 512,
|
"input": "This is a test " * 512,
|
||||||
})
|
})
|
||||||
assert res.status_code != 200
|
assert res.status_code != 200
|
||||||
@ -109,8 +143,9 @@ def test_embedding_error_prompt_too_long():
|
|||||||
|
|
||||||
|
|
||||||
def test_same_prompt_give_same_result():
|
def test_same_prompt_give_same_result():
|
||||||
|
server.pooling = 'last'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
"input": [
|
"input": [
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
|
|||||||
def test_embedding_usage_single(content, n_tokens):
|
def test_embedding_usage_single(content, n_tokens):
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={"input": content})
|
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||||
@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
|
|||||||
def test_embedding_usage_multiple():
|
def test_embedding_usage_multiple():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/embeddings", data={
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
"input": [
|
"input": [
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
"I believe the meaning of life is",
|
"I believe the meaning of life is",
|
||||||
|
@ -65,6 +65,7 @@ class ServerProcess:
|
|||||||
server_reranking: bool | None = False
|
server_reranking: bool | None = False
|
||||||
server_metrics: bool | None = False
|
server_metrics: bool | None = False
|
||||||
server_slots: bool | None = False
|
server_slots: bool | None = False
|
||||||
|
pooling: str | None = None
|
||||||
draft: int | None = None
|
draft: int | None = None
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
response_format: str | None = None
|
response_format: str | None = None
|
||||||
@ -132,6 +133,8 @@ class ServerProcess:
|
|||||||
server_args.append("--metrics")
|
server_args.append("--metrics")
|
||||||
if self.server_slots:
|
if self.server_slots:
|
||||||
server_args.append("--slots")
|
server_args.append("--slots")
|
||||||
|
if self.pooling:
|
||||||
|
server_args.extend(["--pooling", self.pooling])
|
||||||
if self.model_alias:
|
if self.model_alias:
|
||||||
server_args.extend(["--alias", self.model_alias])
|
server_args.extend(["--alias", self.model_alias])
|
||||||
if self.n_ctx:
|
if self.n_ctx:
|
||||||
|
Loading…
Reference in New Issue
Block a user