cached prompt support

FSSRepo 2023-10-12 21:16:12 -04:00
parent 83c2b3553a
commit 500ac7120e


@@ -37,6 +37,11 @@ struct server_params
int32_t write_timeout = 600;
};
// struct beam_search_callback_data {
// llama_server_context* ctx;
// llama_client_slot* slot;
// };
static bool server_verbose = false;
#if SERVER_VERBOSE != 1
@@ -76,7 +81,7 @@ struct slot_params {
uint32_t seed = -1; // RNG seed
int32_t n_predict = 128; // new tokens to predict
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
bool remember_generation = false; // remember a the prompt to avoid reprocessing all prompt
bool cache_prompt = false; // remember the prompt to avoid reprocessing the entire prompt
std::vector<std::string> antiprompt;
json input_prefix;
json input_suffix;
@@ -246,12 +251,9 @@ struct llama_client_slot
llama_grammar *grammar = nullptr;
void reset() {
state = IDLE;
command = NONE;
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -261,6 +263,7 @@ struct llama_client_slot
n_past = 0;
sent_count = 0;
infill = false;
clean_tokens();
if (grammar != nullptr) {
llama_grammar_free(grammar);
@@ -300,7 +303,7 @@ struct llama_client_slot
}
bool hasNewToken() {
return generated_token_probs.size() > sent_tokens;
return num_tokens_predicted > sent_tokens;
}
bool available() {
@@ -308,7 +311,7 @@ struct llama_client_slot
}
bool isProcessing() {
return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING;
return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING;
}
completion_token_output next() {
@@ -319,8 +322,6 @@ struct llama_client_slot
void addTokenString(completion_token_output token) {
if(command == RELEASE) {
generated_token_probs.clear();
sent_tokens = 0;
return;
}
context_tokens.push_back(token.tok);
@@ -333,6 +334,11 @@ struct llama_client_slot
command = RELEASE;
}
}
void clean_tokens() {
sent_tokens = 0;
generated_token_probs.clear();
}
};
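The slot state and command values used throughout this diff (IDLE, SLEEPING, PROCESSING, NONE, LOAD_PROMPT, RELEASE) are defined outside the hunks shown here. A minimal sketch of what those enums presumably look like, with SLEEPING being the state used for slots that hold a cached prompt:

// Assumed sketch; the real definitions are not part of this diff.
// A slot that finishes generating with cache_prompt enabled moves to SLEEPING
// instead of IDLE, keeping its context_tokens so the next prompt can reuse them.
enum slot_state   { IDLE, SLEEPING, PROCESSING };
enum slot_command { NONE, LOAD_PROMPT, RELEASE };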
struct llama_server_context
@@ -626,9 +632,7 @@ struct llama_server_context
const std::string token_str = llama_token_to_piece(ctx, result.tok);
slot.sampled = result.tok;
slot.generated_text += token_str;
size_t pos = std::min(slot.sent_count, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
@@ -737,14 +741,20 @@ struct llama_server_context
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
{
LOG_TEE("slot %i released\n", slot.id);
slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
if(slot.state == SLEEPING) {
printf("%i has cached prompt.");
}
slot.command = NONE;
continue;
}
kv_cache_free -= slot.num_prompt_tokens;
if (slot.state == IDLE || slot.command == RELEASE) {
if (
slot.state == IDLE ||
slot.state == SLEEPING ||
slot.command == RELEASE) {
continue;
}
@@ -765,8 +775,6 @@ struct llama_server_context
// need process the prompt
bool keep_gen = slot.state == SLEEPING; // remember generation
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
std::vector<llama_token> prompt_tokens;
if(slot.infill) {
bool suff_rm_leading_spc = true;
@@ -794,6 +802,9 @@ struct llama_server_context
slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
printf("n_past: %i, context: %i, prompt: %i, cache: %s\n",
slot.n_past ,slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
slot.context_tokens = prompt_tokens;
if (slot.n_past == slot.num_prompt_tokens) {
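The cache reuse above hinges on common_part, which is not shown in this diff. A minimal sketch of what such a helper presumably does, counting how many leading tokens the cached context shares with the incoming prompt (assumed implementation):

// Assumed sketch of common_part (defined elsewhere in this file): length of
// the shared token prefix between the cached context and the new prompt.
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b)
{
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

With cache_prompt enabled, only the tokens after this shared prefix are re-evaluated, which is why n_past is set to the returned length.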
@@ -812,7 +823,7 @@ struct llama_server_context
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
//printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
printf("%s", llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
batch.seq_id[batch.n_tokens] = slot.id;
@@ -827,6 +838,8 @@ struct llama_server_context
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.state = PROCESSING;
slot.command = NONE;
}
}
}
@@ -868,10 +881,18 @@ struct llama_server_context
}
for (auto & slot : slots) {
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
continue;
}
// prompt evaluated for embedding
if(params.embedding) {
slot.release();
slot.i_batch = -1;
return true;
}
completion_token_output result;
const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
@@ -1316,6 +1337,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
json res = json{
{"content", content},
{"slot_id", slot->id},
{"stop", true},
{"model", llama.params.model_alias},
{"tokens_predicted", slot->num_tokens_predicted},
@@ -1327,7 +1349,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
{"stopped_word", slot->stopped_word},
{"stopped_limit", slot->stopped_limit},
{"stopping_word", slot->stopping_word},
{"tokens_cached", slot->n_past},
{"tokens_cached", slot->n_past}
// {"timings", format_timings(llama)},
};
@@ -1383,6 +1405,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
llama_sampling_params default_sparams;
slot->params.stream = json_value(body, "stream", false);
slot->params.cache_prompt = json_value(body, "cache_prompt", false);
slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict);
slot->sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
slot->sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
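The json_value helper used above to read cache_prompt and the sampling parameters is defined elsewhere in this file; presumably it is a small template along these lines (assumed sketch, using the nlohmann::json API):

// Assumed sketch of json_value (not part of this diff): return body[key]
// when present and non-null, otherwise fall back to the given default.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value)
{
    return (body.count(key) != 0 && !body.at(key).is_null())
        ? body.at(key).get<T>()
        : default_value;
}

A client opts into prompt caching by sending "cache_prompt": true in the completion request body; otherwise the default of false applies.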
@@ -1495,8 +1518,8 @@ static void log_server_request(const Request &req, const Response &res)
});
}
static bool is_at_eob(llama_server_context &server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
static bool is_at_eob(llama_server_context * server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context->ctx);
}
// Function matching type llama_beam_search_callback_fn_t.
@@ -1509,21 +1532,21 @@ static bool is_at_eob(llama_server_context &server_context, const llama_token *t
// disabled for now to avoid unnecessary headaches
// static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
// auto & llama = *static_cast<llama_server_context*>(callback_data);
// auto & llama = *static_cast<beam_search_callback_data*>(callback_data);
// // Mark beams as EOS as needed.
// for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
// llama_beam_view& beam_view = beams_state.beam_views[i];
// if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
// if (!beam_view.eob && is_at_eob(llama.ctx, beam_view.tokens, beam_view.n_tokens)) {
// beam_view.eob = true;
// }
// }
// printf(","); // Show progress
// if (const size_t n = beams_state.common_prefix_length) {
// llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
// llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
// assert(0u < beams_state.n_beams);
// const llama_token * tokens = beams_state.beam_views[0].tokens;
// const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
// std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
// std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
// printf("%zu", n);
// }
// fflush(stdout);
@@ -1541,17 +1564,17 @@ struct token_translator {
std::string operator()(const completion_token_output & cto) const { return (*this)(cto.tok); }
};
static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot & slot)
static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot* slot)
{
auto & gtps = slot.generated_token_probs;
auto & gtps = slot->generated_token_probs;
auto translator = token_translator{llama.ctx};
auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
if (slot.generated_text.capacity() < slot.generated_text.size() + len) {
slot.generated_text.reserve(slot.generated_text.size() + len);
if (slot->generated_text.capacity() < slot->generated_text.size() + len) {
slot->generated_text.reserve(slot->generated_text.size() + len);
}
for (const completion_token_output & cto : gtps) {
slot.generated_text += translator(cto);
slot->generated_text += translator(cto);
}
}
@@ -1662,17 +1685,19 @@ int main(int argc, char **argv)
std::string completion_text = "";
if (llama.params.n_beams) {
// // Fill llama.generated_token_probs vector with final beam.
// llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
// slot->n_past, llama.n_remain);
// beam_search_callback_data data_;
// data_.slot = slot;
// data_.ctx = &llama;
// llama_beam_search(llama.ctx, beam_search_callback, &data_, llama.params.n_beams,
// slot->n_past, llama.params.n_predict);
// // Translate llama.generated_token_probs to llama.generated_text.
// append_to_generated_text_from_generated_token_probs(llama);
// append_to_generated_text_from_generated_token_probs(llama, slot);
} else {
while (slot->isProcessing()) {
if(slot->hasNewToken()) {
completion_text += slot->next().text_to_send;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::this_thread::sleep_for(std::chrono::microseconds(5));
}
}
}
@@ -1686,7 +1711,7 @@ int main(int argc, char **argv)
const json data = format_final_response(llama, slot, completion_text, probs);
//llama_print_timings(llama.ctx);
slot->release();
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
} else {
@@ -1743,8 +1768,7 @@ int main(int argc, char **argv)
return true;
};
auto on_complete = [slot, &llama] (bool) {
slot->sent_tokens = 0;
slot->generated_token_probs.clear();
slot->clean_tokens();
slot->release();
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -1781,6 +1805,26 @@ int main(int argc, char **argv)
return;
}
if(!slot->params.stream) {
std::string completion_text = "";
while (slot->isProcessing()) {
if(slot->hasNewToken()) {
completion_text += slot->next().text_to_send;
} else {
std::this_thread::sleep_for(std::chrono::microseconds(5));
}
}
auto probs = slot->generated_token_probs;
if (slot->sparams.n_probs > 0 && slot->stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
}
const json data = format_final_response(llama, slot, completion_text, probs);
//llama_print_timings(llama.ctx);
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
} else {
const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
size_t sent_token_probs_index = 0;
while(slot->isProcessing()) {
@@ -1834,11 +1878,11 @@ int main(int argc, char **argv)
return true;
};
auto on_complete = [slot, &llama] (bool) {
slot->sent_tokens = 0;
slot->generated_token_probs.clear();
slot->clean_tokens();
slot->release();
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});
svr.Get("/model.json", [&llama](const Request &, Response &res)
@@ -1878,9 +1922,7 @@ int main(int argc, char **argv)
svr.Post("/embedding", [&llama](const Request &req, Response &res)
{
const json body = json::parse(req.body);
llama_client_slot* slot = llama.getSlot(-1);
slot->reset();
//llama_reset_timings(llama.ctx);
if (body.count("content") != 0)