completion endpoint working

FSSRepo 2023-10-12 16:17:27 -04:00
parent 29c8cdd65d
commit 81484805f0


@@ -91,6 +91,7 @@ struct completion_token_output
     std::vector<token_prob> probs;
     llama_token tok;
+    std::string text_to_send;
 };
 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
@@ -231,6 +232,7 @@ struct llama_client_slot
     bool stopped_limit = false;
     std::string stopping_word;
     int32_t multibyte_pending = 0;
+    size_t sent_count = 0;
     struct slot_params params;
     struct llama_sampling_params sparams;
@@ -453,7 +455,6 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
-            printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }
@@ -492,7 +493,6 @@ struct llama_server_context
         // compare the evaluated prompt with the new prompt
     }
     llama_client_slot* getSlot(int id) {
         for (llama_client_slot & slot : slots)
         {
@@ -703,17 +703,36 @@ struct llama_server_context
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
-        slot.addTokenString(result);
         slot.generated_text += token_str;
-        size_t stop_pos = findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL, slot);
-        bool has_next_token = !(slot.n_decoded > 2 &&
-                (result.tok == llama_token_eos(ctx) ||
-                 (slot.n_decoded + slot.n_past >= params.n_predict) ||
-                 stop_pos != std::string::npos));
+        size_t pos = std::min(slot.sent_count, slot.generated_text.size());
+        const std::string str_test = slot.generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            slot.generated_text.erase(
+                slot.generated_text.begin() + pos + stop_pos,
+                slot.generated_text.end());
+            pos = std::min(slot.sent_count, slot.generated_text.size());
+            result.tok = -1;
+        } else {
+            is_stop_full = false;
+            stop_pos = findStoppingStrings(str_test, token_str.size(),
+                                           STOP_PARTIAL, slot);
+        }
+        bool has_next_token = !is_stop_full && stop_pos > 0;
+        if(stop_pos == std::string::npos) {
+            result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
+            slot.sent_count += result.text_to_send.size();
+            has_next_token = true;
+        }
+        slot.addTokenString(result);
+        if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
+           slot.n_past + slot.n_decoded >= params.n_predict)) {
+            has_next_token = false;
+        }
         if (slot.sparams.n_probs > 0)
         {
             slot.generated_token_probs.push_back(result);
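
Note: together with the new text_to_send and sent_count fields above, this hunk switches streaming from "emit every decoded piece" to "emit only the not-yet-sent suffix of generated_text", truncating at a full stop string and withholding bytes that could still grow into one. What follows is a minimal standalone sketch of that send-window idea under simplified assumptions; StreamState, find_stop, partial_stop_len, and next_chunk are illustrative stand-ins, not the server's actual helpers (findStoppingStrings with STOP_FULL/STOP_PARTIAL plays both matching roles there).

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct StreamState {
    std::string generated_text; // everything decoded so far
    size_t      sent_count = 0; // bytes already streamed to the client
};

// Byte offset of the first full stop string inside `text`, or npos.
static size_t find_stop(const std::string &text, const std::vector<std::string> &stops) {
    size_t best = std::string::npos;
    for (const std::string &s : stops) {
        const size_t p = text.find(s);
        if (p != std::string::npos && p < best) {
            best = p;
        }
    }
    return best;
}

// Length of the longest suffix of `text` that is a proper prefix of some
// stop string: those bytes must be withheld until more text arrives.
static size_t partial_stop_len(const std::string &text, const std::vector<std::string> &stops) {
    size_t best = 0;
    for (const std::string &s : stops) {
        const size_t max_k = std::min(s.size() - 1, text.size());
        for (size_t k = max_k; k > best; --k) {
            if (text.compare(text.size() - k, k, s, 0, k) == 0) {
                best = k;
                break;
            }
        }
    }
    return best;
}

// Appends one decoded piece and returns the suffix that is safe to send.
static std::string next_chunk(StreamState &st, const std::string &piece,
                              const std::vector<std::string> &stops, bool &done) {
    st.generated_text += piece;
    const size_t pos  = std::min(st.sent_count, st.generated_text.size());
    std::string  tail = st.generated_text.substr(pos);
    const size_t stop_pos = find_stop(tail, stops);
    if (stop_pos != std::string::npos) {
        // full stop string found: drop it and everything after, end the stream
        st.generated_text.erase(st.generated_text.begin() + pos + stop_pos,
                                st.generated_text.end());
        tail = st.generated_text.substr(pos);
        done = true;
    } else {
        // withhold bytes that could still grow into a stop string
        tail.erase(tail.size() - partial_stop_len(tail, stops));
    }
    st.sent_count += tail.size();
    return tail;
}

int main() {
    StreamState st;
    bool done = false;
    // the stop marker "<|end|>" arrives split across two pieces
    for (const char *piece : {"Hello", ", wor", "ld!<", "|end|>", " junk"}) {
        const std::string chunk = next_chunk(st, piece, {"<|end|>"}, done);
        if (!chunk.empty()) {
            printf("send: %s\n", chunk.c_str());
        }
        if (done) {
            break;
        }
    }
    return 0;
}

With the marker split like this, the sketch prints "Hello", ", wor", "ld!" and never leaks the held-back "<" or anything after the stop string.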
@@ -804,6 +823,7 @@ struct llama_server_context
             slot.state = IDLE;
             slot.command = NONE;
             slot.generated_text.clear();
+            return true;
         } else {
             slot.state = SLEEPING;
             slot.command = NONE;
@@ -1748,10 +1768,9 @@ int main(int argc, char **argv)
             while(slot->isProcessing()) {
                 if(slot->hasNewToken()) { // new token notification
                     const completion_token_output token = slot->next();
-                    std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
                     std::vector<completion_token_output> probs_output = {};
                     if (slot->sparams.n_probs > 0) {
-                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token_str, false);
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                         size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
                         if (probs_pos < probs_stop_pos) {
@@ -1759,7 +1778,7 @@ int main(int argc, char **argv)
                         }
                         sent_token_probs_index = probs_stop_pos;
                     }
-                    const json data = format_partial_response(llama, slot, token_str, probs_output);
+                    const json data = format_partial_response(llama, slot, token.text_to_send, probs_output);
                     const std::string str =
                         "data: " +
                         data.dump(-1, ' ', false, json::error_handler_t::replace) +
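
Note: the two hunks above keep the token-probability side channel aligned with the text stream: each sent chunk is re-tokenized so that sent_token_probs_index can advance by exactly the number of tokens the chunk covers. A hedged sketch of that windowing follows; TokenProb and probs_for_chunk are illustrative stand-ins for completion_token_output and the inline logic, not the server's API.

#include <algorithm>
#include <cstddef>
#include <vector>

struct TokenProb { int tok; float prob; }; // stand-in for completion_token_output

// Returns the slice of per-token probabilities belonging to a chunk that
// re-tokenized into `n_chunk_tokens` tokens, advancing the sent index past it.
static std::vector<TokenProb> probs_for_chunk(const std::vector<TokenProb> &all,
                                              size_t n_chunk_tokens,
                                              size_t &sent_token_probs_index) {
    const size_t pos  = std::min(sent_token_probs_index, all.size());
    const size_t stop = std::min(sent_token_probs_index + n_chunk_tokens, all.size());
    std::vector<TokenProb> out(all.begin() + pos, all.begin() + stop);
    sent_token_probs_index = stop; // the next chunk starts where this one ended
    return out;
}

Clamping both ends with std::min keeps the slice valid even when re-tokenizing the sent text yields a different token count than was originally decoded.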
@@ -1796,7 +1815,7 @@ int main(int argc, char **argv)
             sink.done();
             return true;
         };
-        auto on_complete = [&] (bool) {
+        auto on_complete = [slot, &llama] (bool) {
             llama.mutex.unlock();
             slot->sent_tokens = 0;
             slot->generated_token_probs.clear();
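
Note: the final hunk narrows the on_complete capture from [&] to [slot, &llama], so the slot pointer is copied into the closure and remains valid if the callback fires after the enclosing handler scope has unwound. A minimal illustration of the hazard, unrelated to the server's actual types:

#include <cstdio>
#include <functional>

// Returns a callback that outlives the scope it was created in.
static std::function<void()> make_callback() {
    int slot_id = 7;
    // Capturing by value copies slot_id into the closure; a [&] capture
    // would leave a dangling reference once make_callback() returns.
    return [slot_id]() { printf("slot %d done\n", slot_id); };
}

int main() {
    const auto cb = make_callback();
    cb(); // safe: prints "slot 7 done"
    return 0;
}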