[Fix] Reenable server embedding endpoint (#1937)

* Add back embedding feature

* Update README
Henri Vasserman 2023-06-20 01:12:39 +03:00 committed by GitHub
parent 18b35625c3
commit 20568fe60f
2 changed files with 54 additions and 3 deletions

examples/server/README.md

@@ -21,6 +21,7 @@ Command line options:
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
- `--embedding`: Enable embedding extraction. Default: disabled.
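A minimal launch sketch, assuming the example has been built as `./server` and `-m` points to a converted model (the model path below is hypothetical):

```sh
# start the server with embedding extraction enabled (model path is hypothetical)
./server -m models/7B/ggml-model.bin --embedding --host 127.0.0.1 --port 8080
```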
## Build
@@ -119,14 +120,14 @@ node .
`top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: 128, -1 = infinity).
`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
`stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
`prompt`: Provide a prompt. Internally, the prompt is compared against the previous one; any part that has already been evaluated is reused, and only the remaining part is evaluated. A space is inserted at the front, as main.cpp does.
`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
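A request sketch for the `n_predict: 0` case described above, assuming the server is listening on the default `127.0.0.1:8080`; the prompt is evaluated into the cache and no tokens are generated:

```sh
# prime the prompt cache only; n_predict = 0 means no tokens are generated
curl --request POST \
    --url http://127.0.0.1:8080/completion \
    --header "Content-Type: application/json" \
    --data '{"prompt": "Building a website can be done in 10 simple steps:", "n_predict": 0}'
```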
@@ -163,6 +164,14 @@ node .
`content`: Set the text to tokenize.
Note that the special `BOS` token is not added in front of the text, and a space character is not inserted automatically as it is for `/completion`.
- **POST** `/embedding`: Generate the embedding of a given text, just as [the embedding example](../embedding) does.
*Options:*
`content`: Set the text to process.
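A usage sketch, assuming the server was started with `--embedding` and is listening on the default `127.0.0.1:8080`; the response is a JSON object containing a single `embedding` array:

```sh
# request the embedding vector for a short text
curl --request POST \
    --url http://127.0.0.1:8080/embedding \
    --header "Content-Type: application/json" \
    --data '{"content": "Hello, world!"}'
```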
## More examples
### Interactive mode

examples/server/server.cpp

@@ -254,6 +254,11 @@ struct llama_server_context {
n_past += n_eval;
}
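// prompt-only request (n_predict == 0): the prompt has been evaluated, so stop generating and return end-of-sequence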
if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
}
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -419,6 +424,19 @@ struct llama_server_context {
return token_text;
}
std::vector<float> getEmbedding() {
static const int n_embd = llama_n_embd(ctx);
if (!params.embedding) {
LOG_WARNING("embedding disabled", {
{ "params.embedding", params.embedding },
});
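// embeddings were not enabled at launch; return an all-zero vector of length n_embd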
return std::vector<float>(n_embd, 0.0f);
}
const float * data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
return embedding;
}
};
static void server_print_usage(const char * argv0, const gpt_params & params,
@@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
fprintf(stderr, " --host ip address to listen (default: %s)\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT port to listen (default: %d)\n", sparams.port);
fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
fprintf(stderr, " --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
fprintf(stderr, "\n");
}
@@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
params.use_mlock = true;
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--embedding") {
params.embedding = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argv[0], default_params, default_sparams);
@@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
};
}
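// wrap the embedding vector in a JSON object of the form { "embedding": [ ... ] }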
static json format_embedding_response(llama_server_context & llama) {
return json {
{ "embedding", llama.getEmbedding() },
};
}
static json format_final_response(llama_server_context & llama, const std::string & content) {
return json {
{ "content", content },
@@ -881,12 +908,27 @@ int main(int argc, char ** argv) {
svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body);
const std::string content = body.value("content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json");
});
svr.Post("/embedding", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body);
llama.rewind();
llama_reset_timings(llama.ctx);
llama.params.prompt = body.value("content", "");
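// evaluate the prompt only; no tokens are generated before the embedding is read back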
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();
llama.doCompletion();
const json data = format_embedding_response(llama);
return res.set_content(data.dump(), "application/json");
});
svr.set_logger(log_server_request);
svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {