Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 20:04:35 +00:00)
completion endpoint working
parent 29c8cdd65d
commit 81484805f0
@@ -91,6 +91,7 @@ struct completion_token_output

     std::vector<token_prob> probs;
     llama_token tok;
+    std::string text_to_send;
 };

 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
@@ -231,6 +232,7 @@ struct llama_client_slot
     bool stopped_limit = false;
     std::string stopping_word;
     int32_t multibyte_pending = 0;
+    size_t sent_count = 0;

     struct slot_params params;
     struct llama_sampling_params sparams;
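Taken together, the two added members carry the streaming state for each request: text_to_send (the hunk at line 91) holds the piece of text that is safe to emit for the current step, while sent_count records how many characters of the slot's generated_text have already been streamed. A minimal standalone sketch of that bookkeeping, using hypothetical simplified types rather than the server's real structs:

// Minimal sketch, not the server's real structs: just the two pieces of
// per-slot streaming state added in this commit.
#include <algorithm>
#include <cstddef>
#include <string>

struct slot_sketch {
    std::string generated_text;   // everything generated so far for this request
    size_t      sent_count = 0;   // chars of generated_text already streamed
};

// Return the part of generated_text the client has not seen yet (what the
// server places into result.text_to_send) and advance sent_count past it.
static std::string take_unsent(slot_sketch &slot) {
    const size_t pos = std::min(slot.sent_count, slot.generated_text.size());
    std::string chunk = slot.generated_text.substr(pos);
    slot.sent_count += chunk.size();
    return chunk;
}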
@@ -453,7 +455,6 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
-            printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }

@@ -492,7 +493,6 @@ struct llama_server_context
         // compare the evaluated prompt with the new prompt
     }

-
     llama_client_slot* getSlot(int id) {
         for (llama_client_slot & slot : slots)
         {
@@ -703,17 +703,36 @@ struct llama_server_context
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
-        slot.addTokenString(result);
         slot.generated_text += token_str;

-        size_t stop_pos = findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL, slot);
-
-        bool has_next_token = !(slot.n_decoded > 2 &&
-                (result.tok == llama_token_eos(ctx) ||
-                 (slot.n_decoded + slot.n_past >=
-                  params.n_predict) ||
-                 stop_pos != std::string::npos));
+        size_t pos = std::min(slot.sent_count, slot.generated_text.size());
+
+        const std::string str_test = slot.generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            slot.generated_text.erase(
+                slot.generated_text.begin() + pos + stop_pos,
+                slot.generated_text.end());
+            pos = std::min(slot.sent_count, slot.generated_text.size());
+            result.tok = -1;
+        } else {
+            is_stop_full = false;
+            stop_pos = findStoppingStrings(str_test, token_str.size(),
+                                           STOP_PARTIAL, slot);
+        }
+        bool has_next_token = !is_stop_full && stop_pos > 0;
+        if(stop_pos == std::string::npos) {
+            result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
+            slot.sent_count += result.text_to_send.size();
+            has_next_token = true;
+        }
+        slot.addTokenString(result);
+        if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
+           slot.n_past + slot.n_decoded >= params.n_predict)) {
+            has_next_token = false;
+        }
         if (slot.sparams.n_probs > 0)
         {
             slot.generated_token_probs.push_back(result);
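The rewritten block above is the core of the change. Instead of scanning the whole generated_text and folding the stop test into a single has_next_token expression, the loop now looks only at the unsent tail (str_test): a full stop-string match truncates generated_text at the match, a partial match (the tail could still grow into a stop word) holds the text back, and no match at all releases the tail through result.text_to_send and advances sent_count. A self-contained sketch of that decision, with find_stopping_strings_sketch standing in for the server's findStoppingStrings (its real signature differs), so the helper names and types below are assumptions:

// Sketch of the incremental stop-string check, assuming a simplified helper API.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

enum class stop_type_sketch { FULL, PARTIAL };

// Hypothetical helper: FULL looks for a complete stop word inside `text`,
// PARTIAL looks for a suffix of `text` that is a proper prefix of a stop word
// (it might still complete on the next token).
static size_t find_stopping_strings_sketch(const std::string &text,
                                           const std::vector<std::string> &stops,
                                           stop_type_sketch type) {
    for (const std::string &stop : stops) {
        if (stop.empty()) { continue; }
        if (type == stop_type_sketch::FULL) {
            const size_t p = text.find(stop);
            if (p != std::string::npos) { return p; }
        } else {
            const size_t max_len = std::min(stop.size() - 1, text.size());
            for (size_t len = max_len; len > 0; --len) {
                if (text.compare(text.size() - len, len, stop, 0, len) == 0) {
                    return text.size() - len;
                }
            }
        }
    }
    return std::string::npos;
}

struct stream_decision_sketch {
    std::string text_to_send;      // safe to emit to the client now
    bool        hit_stop = false;  // a full stop string was found
};

// Decide how much of the unsent tail can be streamed without leaking a stop word.
static stream_decision_sketch advance_stream(std::string &generated, size_t &sent_count,
                                             const std::vector<std::string> &stops) {
    stream_decision_sketch out;
    size_t pos = std::min(sent_count, generated.size());
    const std::string tail = generated.substr(pos);

    size_t stop_pos = find_stopping_strings_sketch(tail, stops, stop_type_sketch::FULL);
    if (stop_pos != std::string::npos) {
        // Full match: drop the stop word and everything after it from the output.
        generated.erase(generated.begin() + pos + stop_pos, generated.end());
        out.hit_stop = true;
    } else {
        stop_pos = find_stopping_strings_sketch(tail, stops, stop_type_sketch::PARTIAL);
    }

    if (stop_pos == std::string::npos) {
        // Nothing that even looks like a stop word: release the whole tail.
        out.text_to_send = generated.substr(pos);
        sent_count += out.text_to_send.size();
    }
    // On a PARTIAL match the tail is held back until more tokens arrive.
    return out;
}

Scanning only the unsent tail avoids re-checking text that is already on the wire and, more importantly, keeps a half-formed stop word from leaking to the client mid-stream.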
@@ -804,6 +823,7 @@ struct llama_server_context
             slot.state = IDLE;
             slot.command = NONE;
+            slot.generated_text.clear();
             return true;
         } else {
             slot.state = SLEEPING;
             slot.command = NONE;
@@ -1748,10 +1768,9 @@ int main(int argc, char **argv)
                 while(slot->isProcessing()) {
                     if(slot->hasNewToken()) { // new token notification
                         const completion_token_output token = slot->next();
-                        std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
                         std::vector<completion_token_output> probs_output = {};
                         if (slot->sparams.n_probs > 0) {
-                            const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token_str, false);
+                            const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                             size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                             size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
                             if (probs_pos < probs_stop_pos) {
@@ -1759,7 +1778,7 @@ int main(int argc, char **argv)
                             }
                             sent_token_probs_index = probs_stop_pos;
                         }
-                        const json data = format_partial_response(llama, slot, token_str, probs_output);
+                        const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
                         const std::string str =
                             "data: " +
                             data.dump(-1, ' ', false, json::error_handler_t::replace) +
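Because a streamed chunk no longer corresponds one-to-one to a sampled token, the handler above re-tokenizes token.text_to_send and uses sent_token_probs_index to slice the matching window out of slot->generated_token_probs. A rough sketch of that slicing with simplified, made-up types (the real code works on completion_token_output records and llama_tokenize):

// Sketch: align a streamed text chunk with its per-token probability records.
#include <algorithm>
#include <cstddef>
#include <vector>

struct token_probs_sketch { int tok = 0; /* candidate probabilities omitted */ };

// `all_probs`        : probabilities recorded for every generated token so far
// `sent_probs_index` : how many of those records were already streamed
// `chunk_tokens`     : number of tokens the current chunk re-tokenizes to
//                      (stands in for llama_tokenize(token.text_to_send).size())
static std::vector<token_probs_sketch> next_probs_window(
        const std::vector<token_probs_sketch> &all_probs,
        size_t &sent_probs_index,
        size_t chunk_tokens) {
    const size_t begin = std::min(sent_probs_index, all_probs.size());
    const size_t end   = std::min(sent_probs_index + chunk_tokens, all_probs.size());
    std::vector<token_probs_sketch> window;
    if (begin < end) {
        window.assign(all_probs.begin() + begin, all_probs.begin() + end);
    }
    sent_probs_index = end;  // the next chunk starts where this one stopped
    return window;
}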
@@ -1796,7 +1815,7 @@ int main(int argc, char **argv)
                     sink.done();
                     return true;
                 };
-                auto on_complete = [&] (bool) {
+                auto on_complete = [slot, &llama] (bool) {
                     llama.mutex.unlock();
                     slot->sent_tokens = 0;
                     slot->generated_token_probs.clear();
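The last hunk narrows the on_complete capture from [&] to [slot, &llama]. The completion callback is typically invoked by the HTTP layer after the handler that registered it has returned, so a blanket reference capture can leave dangling references to the handler's locals; copying the slot pointer and referencing the long-lived llama context keeps exactly what the cleanup needs. A hypothetical illustration of the pattern (the types and the deferred-callback mechanism here are assumptions, not the server's code):

// Sketch: why a deferred completion callback captures explicitly.
#include <functional>
#include <mutex>

struct slot_sketch {
    int sent_tokens = 0;
};

struct server_sketch {
    std::mutex mutex;
    std::function<void(bool)> on_complete;  // invoked by the framework later
};

// The handler registers the cleanup callback and returns; the streaming
// framework calls it once the response has finished.
static void handle_request(server_sketch &server, slot_sketch *slot) {
    server.mutex.lock();  // held for the whole streamed response, as in the diff
    // ... queue work, register the chunked content provider ...

    // Capture the slot pointer by value and the server by reference. A blanket
    // [&] would also capture handle_request's locals by reference, and those
    // are gone by the time the callback finally runs.
    server.on_complete = [slot, &server](bool /*success*/) {
        server.mutex.unlock();
        slot->sent_tokens = 0;
    };
}

// Later, once streaming is done, the framework would call:
//     server.on_complete(true);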