mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 20:04:35 +00:00)

completion endpoint working

commit 81484805f0, parent 29c8cdd65d
@@ -91,6 +91,7 @@ struct completion_token_output
 
     std::vector<token_prob> probs;
     llama_token tok;
+    std::string text_to_send;
 };
 
 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
@@ -231,6 +232,7 @@ struct llama_client_slot
     bool stopped_limit = false;
     std::string stopping_word;
     int32_t multibyte_pending = 0;
+    size_t sent_count = 0;
 
     struct slot_params params;
     struct llama_sampling_params sparams;
@@ -453,7 +455,6 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
-            printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }
 
@@ -492,7 +493,6 @@ struct llama_server_context
         // compare the evaluated prompt with the new prompt
     }
 
-
     llama_client_slot* getSlot(int id) {
         for (llama_client_slot & slot : slots)
         {
@@ -703,17 +703,36 @@ struct llama_server_context
             slot.last_n_tokens.push_back(result.tok);
             const std::string token_str = llama_token_to_piece(ctx, result.tok);
             slot.sampled = result.tok;
-            slot.addTokenString(result);
             slot.generated_text += token_str;
 
-            size_t stop_pos = findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL, slot);
+            size_t pos = std::min(slot.sent_count, slot.generated_text.size());
 
-            bool has_next_token = !(slot.n_decoded > 2 &&
-                    (result.tok == llama_token_eos(ctx) ||
-                     (slot.n_decoded + slot.n_past >=
-                      params.n_predict) ||
-                     stop_pos != std::string::npos));
-
+            const std::string str_test = slot.generated_text.substr(pos);
+            bool is_stop_full = false;
+            size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
+            if (stop_pos != std::string::npos) {
+                is_stop_full = true;
+                slot.generated_text.erase(
+                    slot.generated_text.begin() + pos + stop_pos,
+                    slot.generated_text.end());
+                pos = std::min(slot.sent_count, slot.generated_text.size());
+                result.tok = -1;
+            } else {
+                is_stop_full = false;
+                stop_pos = findStoppingStrings(str_test, token_str.size(),
+                                               STOP_PARTIAL, slot);
+            }
+            bool has_next_token = !is_stop_full && stop_pos > 0;
+            if(stop_pos == std::string::npos) {
+                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
+                slot.sent_count += result.text_to_send.size();
+                has_next_token = true;
+            }
+            slot.addTokenString(result);
+            if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
+                slot.n_past + slot.n_decoded >= params.n_predict)) {
+                has_next_token = false;
+            }
             if (slot.sparams.n_probs > 0)
             {
                 slot.generated_token_probs.push_back(result);
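
The core of this hunk is that stop-string scanning now runs only over the not-yet-sent tail of `generated_text`, with the new `sent_count` field marking how much has already been streamed: a full stop match truncates the buffer, a partial match holds text back, and `result.text_to_send` is only filled (and `sent_count` advanced) when no stop was found in the tail. A minimal standalone sketch of that bookkeeping, assuming simplified stand-in types and helpers (`Slot`, `find_stop` and `process_piece` are illustrative, not the server's actual API):

// Standalone sketch (not the server's real code): incremental stop-string
// handling over the unsent tail of the generated text.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct Slot {
    std::string generated_text; // everything generated so far
    size_t      sent_count = 0; // bytes already streamed to the client
};

// Find the earliest full stop string, or (if partial == true) the earliest
// position where a stop string may still be forming at the end of `text`.
static size_t find_stop(const std::string & text, const std::vector<std::string> & stops, bool partial) {
    size_t best = std::string::npos;
    for (const std::string & s : stops) {
        if (s.empty()) {
            continue;
        }
        if (!partial) {
            const size_t p = text.find(s);
            if (p < best) {
                best = p;
            }
        } else {
            for (size_t len = s.size() - 1; len > 0; --len) {
                if (text.size() >= len && text.compare(text.size() - len, len, s, 0, len) == 0) {
                    best = std::min(best, text.size() - len);
                    break;
                }
            }
        }
    }
    return best;
}

// Append a freshly decoded piece and return whatever is now safe to stream.
static std::string process_piece(Slot & slot, const std::string & piece, const std::vector<std::string> & stops) {
    slot.generated_text += piece;

    size_t pos = std::min(slot.sent_count, slot.generated_text.size());
    const std::string tail = slot.generated_text.substr(pos); // only the unsent part is scanned

    const size_t stop_full = find_stop(tail, stops, /*partial =*/ false);
    if (stop_full != std::string::npos) {
        // full stop string: drop it and everything after it
        slot.generated_text.erase(slot.generated_text.begin() + pos + stop_full, slot.generated_text.end());
    } else if (find_stop(tail, stops, /*partial =*/ true) != std::string::npos) {
        // a stop string may still be forming: hold the tail back for now
        return "";
    }

    std::string to_send = slot.generated_text.substr(pos);
    slot.sent_count += to_send.size();
    return to_send;
}

int main() {
    Slot slot;
    const std::vector<std::string> stops  = {"###"};
    const std::vector<std::string> pieces = {"Hello", " wor", "ld!", " ##", "# this is cut"};
    for (const std::string & piece : pieces) {
        const std::string out = process_piece(slot, piece, stops);
        printf("piece=%-14s send=\"%s\"\n", piece.c_str(), out.c_str());
    }
    return 0;
}

The " ##" piece is held back because it could be the start of the "###" stop string, and once the full stop arrives everything from it onward is dropped; that is the same hold-back behaviour the hunk implements with `STOP_PARTIAL` versus `STOP_FULL`.
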
@@ -804,6 +823,7 @@ struct llama_server_context
                 slot.state = IDLE;
                 slot.command = NONE;
                 slot.generated_text.clear();
+                return true;
             } else {
                 slot.state = SLEEPING;
                 slot.command = NONE;
@@ -1748,10 +1768,9 @@ int main(int argc, char **argv)
            while(slot->isProcessing()) {
                if(slot->hasNewToken()) { // new token notification
                    const completion_token_output token = slot->next();
-                    std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
                    std::vector<completion_token_output> probs_output = {};
                    if (slot->sparams.n_probs > 0) {
-                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token_str, false);
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                        size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
                        if (probs_pos < probs_stop_pos) {
@@ -1759,7 +1778,7 @@ int main(int argc, char **argv)
                        }
                        sent_token_probs_index = probs_stop_pos;
                    }
-                    const json data = format_partial_response(llama, slot, token_str, probs_output);
+                    const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
                    const std::string str =
                        "data: " +
                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
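
This is the point where the new `text_to_send` chunk actually goes over the wire: each partial result is serialized to JSON and framed as a single server-sent event. A rough sketch of that framing, with `build_chunk` as a hypothetical stand-in for `format_partial_response` (the real server builds the payload with nlohmann::json and writes it through httplib's DataSink):

// Sketch only: JSON payload + SSE framing for one streamed chunk.
#include <cstdio>
#include <string>

// hypothetical stand-in for format_partial_response(): wrap the text in a
// minimal JSON object (naive escaping, enough for the illustration)
static std::string build_chunk(const std::string & text_to_send, bool stop) {
    std::string payload = "{\"content\":\"";
    for (const char c : text_to_send) {
        if (c == '"' || c == '\\') {
            payload += '\\';
        }
        payload += c;
    }
    payload += std::string("\",\"stop\":") + (stop ? "true" : "false") + "}";
    return payload;
}

// one server-sent event: a "data: ..." line terminated by a blank line
static std::string sse_event(const std::string & payload) {
    return "data: " + payload + "\n\n";
}

int main() {
    printf("%s", sse_event(build_chunk("Hello",  false)).c_str());
    printf("%s", sse_event(build_chunk(" world", true)).c_str());
    return 0;
}
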
@@ -1796,7 +1815,7 @@ int main(int argc, char **argv)
            sink.done();
            return true;
        };
-        auto on_complete = [&] (bool) {
+        auto on_complete = [slot, &llama] (bool) {
            llama.mutex.unlock();
            slot->sent_tokens = 0;
            slot->generated_token_probs.clear();
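
The capture change on `on_complete` (from `[&]` to `[slot, &llama]`) is presumably about lifetime: the callback fires after the enclosing handler scope has been left, so copying the `slot` pointer into the closure keeps it usable, whereas a blanket by-reference capture of locals would dangle. A toy illustration of that difference, using hypothetical names (`make_done_callback` is not the server's code):

// Toy illustration (hypothetical): why a deferred callback should capture the
// pointer by value rather than capture local variables by reference.
#include <cstdio>
#include <functional>

struct Slot {
    int sent_tokens = 42;
};

static std::function<void()> make_done_callback(Slot * slot) {
    // the pointer `slot` is copied into the closure, so the callback stays
    // valid after make_done_callback() returns; `[&]` would instead hold a
    // reference to the local parameter, which dies with this stack frame
    return [slot]() {
        slot->sent_tokens = 0;
        printf("slot reset, sent_tokens=%d\n", slot->sent_tokens);
    };
}

int main() {
    Slot slot;
    auto on_complete = make_done_callback(&slot); // the creating scope could end here
    on_complete();                                // fires later, still safe
    return 0;
}
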