cached prompt support

commit 500ac7120e (parent 83c2b3553a)
mirror of https://github.com/ggerganov/llama.cpp.git
@@ -37,6 +37,11 @@ struct server_params
     int32_t write_timeout = 600;
 };
 
+// struct beam_search_callback_data {
+//     llama_server_context* ctx;
+//     llama_client_slot* slot;
+// };
+
 static bool server_verbose = false;
 
 #if SERVER_VERBOSE != 1
@@ -76,7 +81,7 @@ struct slot_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_predict = 128; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
-    bool remember_generation = false; // remember the prompt to avoid reprocessing all of the prompt
+    bool cache_prompt = false; // remember the prompt to avoid reprocessing all of the prompt
     std::vector<std::string> antiprompt;
     json input_prefix;
     json input_suffix;
@@ -246,12 +251,9 @@ struct llama_client_slot
     llama_grammar *grammar = nullptr;
 
     void reset() {
-        state = IDLE;
-        command = NONE;
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
-        generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -261,6 +263,7 @@ struct llama_client_slot
         n_past = 0;
         sent_count = 0;
         infill = false;
+        clean_tokens();
 
         if (grammar != nullptr) {
             llama_grammar_free(grammar);
@@ -300,7 +303,7 @@ struct llama_client_slot
     }
 
     bool hasNewToken() {
-        return generated_token_probs.size() > sent_tokens;
+        return num_tokens_predicted > sent_tokens;
     }
 
     bool available() {
@@ -308,7 +311,7 @@ struct llama_client_slot
     }
 
     bool isProcessing() {
-        return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING;
+        return (state == IDLE || state == SLEEPING) && command == LOAD_PROMPT || state == PROCESSING;
     }
 
     completion_token_output next() {
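For orientation while reading the slot predicates above: the diff only touches isProcessing() and hasNewToken(), but together with the SLEEPING value they imply a small state machine. Below is a minimal, self-contained sketch of that machine; the enum constants appear in the diff, while the type names slot_state/slot_command and this standalone arrangement are assumptions for illustration, not the file's actual declarations.

// Sketch only: a self-contained model of the slot lifecycle implied by the
// predicates above. The type names slot_state/slot_command are assumed.
#include <cstdio>

enum slot_state   { IDLE, SLEEPING, PROCESSING };
enum slot_command { NONE, LOAD_PROMPT, RELEASE };

struct slot_sketch {
    slot_state   state   = IDLE;
    slot_command command = NONE;
    int num_tokens_predicted = 0;
    int sent_tokens          = 0;

    // A slot is "processing" once a prompt has been queued on it (from IDLE or
    // from the SLEEPING cached state) or while it is still decoding.
    bool isProcessing() const {
        return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING;
    }

    // New tokens exist when more were predicted than have been sent out.
    bool hasNewToken() const { return num_tokens_predicted > sent_tokens; }
};

int main() {
    slot_sketch s;
    s.state   = SLEEPING;     // slot kept alive with a cached prompt
    s.command = LOAD_PROMPT;  // a new request re-uses it
    std::printf("processing: %d\n", s.isProcessing());  // prints 1
}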
@ -319,8 +322,6 @@ struct llama_client_slot
|
||||
|
||||
void addTokenString(completion_token_output token) {
|
||||
if(command == RELEASE) {
|
||||
generated_token_probs.clear();
|
||||
sent_tokens = 0;
|
||||
return;
|
||||
}
|
||||
context_tokens.push_back(token.tok);
|
||||
@@ -333,6 +334,11 @@ struct llama_client_slot
             command = RELEASE;
         }
     }
+
+    void clean_tokens() {
+        sent_tokens = 0;
+        generated_token_probs.clear();
+    }
 };
 
 struct llama_server_context
@@ -626,9 +632,7 @@ struct llama_server_context
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
         slot.generated_text += token_str;
-
         size_t pos = std::min(slot.sent_count, slot.generated_text.size());
-
         const std::string str_test = slot.generated_text.substr(pos);
         bool is_stop_full = false;
         size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
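The token loop above appends each decoded piece to slot.generated_text and scans only the not-yet-sent tail (from slot.sent_count) for stop words via findStoppingStrings(..., STOP_FULL, ...). The helper below is a hedged, generic sketch of what such a scan typically does, including a STOP_PARTIAL mode that holds back text which might be the start of a stop word; it is illustrative only and not the file's implementation.

// Sketch only: a generic stop-word scan in the spirit of the
// findStoppingStrings(str_test, ...) call above. Names and the exact
// STOP_FULL/STOP_PARTIAL behaviour are assumptions, not the file's code.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

enum stop_type { STOP_FULL, STOP_PARTIAL };

// Returns the earliest position of a stop word in `text`, or npos. For
// STOP_PARTIAL it also reports when the tail of `text` could be the start of
// a stop word, so the caller can hold those characters back from the client.
static size_t find_stopping_strings(const std::string & text,
                                    const std::vector<std::string> & stop_words,
                                    stop_type type) {
    size_t stop_pos = std::string::npos;
    for (const std::string & word : stop_words) {
        if (word.empty()) {
            continue;
        }
        if (type == STOP_FULL) {
            const size_t pos = text.find(word);
            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
                stop_pos = pos;
            }
        } else {
            // longest suffix of `text` that is a proper prefix of `word`
            for (size_t n = std::min(text.size(), word.size() - 1); n > 0; --n) {
                if (text.compare(text.size() - n, n, word, 0, n) == 0) {
                    const size_t pos = text.size() - n;
                    if (stop_pos == std::string::npos || pos < stop_pos) {
                        stop_pos = pos;
                    }
                    break;
                }
            }
        }
    }
    return stop_pos;
}

int main() {
    const std::vector<std::string> stops = {"</s>"};
    std::printf("full:    %zu\n", find_stopping_strings("Hello</s>", stops, STOP_FULL));    // 5
    std::printf("partial: %zu\n", find_stopping_strings("Hello</",  stops, STOP_PARTIAL));  // 5
}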
@@ -737,14 +741,20 @@ struct llama_server_context
             if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
             {
                 LOG_TEE("slot %i released\n", slot.id);
-                slot.state = slot.params.remember_generation ? SLEEPING : IDLE;
+                slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
+                if(slot.state == SLEEPING) {
+                    printf("slot %i has cached prompt.\n", slot.id);
+                }
                 slot.command = NONE;
                 continue;
             }
 
             kv_cache_free -= slot.num_prompt_tokens;
 
-            if (slot.state == IDLE || slot.command == RELEASE) {
+            if (
+                slot.state == IDLE ||
+                slot.state == SLEEPING ||
+                slot.command == RELEASE) {
                 continue;
             }
 
@ -765,8 +775,6 @@ struct llama_server_context
|
||||
// need process the prompt
|
||||
bool keep_gen = slot.state == SLEEPING; // remember generation
|
||||
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
if(slot.infill) {
|
||||
bool suff_rm_leading_spc = true;
|
||||
@@ -794,6 +802,9 @@ struct llama_server_context
 
                 slot.n_past = keep_gen ? common_part(slot.context_tokens, prompt_tokens) : 0;
 
+                printf("n_past: %i, context: %zu, prompt: %zu, cache: %s\n",
+                    slot.n_past, slot.context_tokens.size(), prompt_tokens.size(), keep_gen ? "true" : "false");
+
                 slot.context_tokens = prompt_tokens;
 
                 if (slot.n_past == slot.num_prompt_tokens) {
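The cache hit itself comes from common_part(slot.context_tokens, prompt_tokens) above: the slot keeps the tokens it last evaluated, and only the part of the new prompt after the shared prefix has to be decoded again. A minimal sketch of such a longest-common-prefix helper, under the assumption that this is what common_part computes (the function in the file may differ in detail):

// Sketch only: a longest-common-prefix helper in the spirit of the
// common_part(slot.context_tokens, prompt_tokens) call above; the function in
// the file may be implemented differently.
#include <cstddef>
#include <cstdio>
#include <vector>

using llama_token = int;  // stand-in for the real token type

static size_t common_part(const std::vector<llama_token> & a,
                          const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    const std::vector<llama_token> cached = {1, 15043, 29892, 590, 1024};  // tokens already in the slot
    const std::vector<llama_token> prompt = {1, 15043, 29892, 590, 3186};  // new prompt, shared prefix of 4
    const size_t n_past = common_part(cached, prompt);
    // Only the tokens after the shared prefix need to be decoded again.
    std::printf("cached prefix: %zu, tokens to evaluate: %zu\n", n_past, prompt.size() - n_past);
}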
@@ -812,7 +823,7 @@ struct llama_server_context
 
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
                 for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
-                    //printf(llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
+                    printf("%s", llama_token_to_piece(ctx, prompt_tokens[slot.n_past]).c_str());
                     batch.token [batch.n_tokens] = prompt_tokens[slot.n_past];
                     batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
                     batch.seq_id[batch.n_tokens] = slot.id;
@@ -827,6 +838,8 @@ struct llama_server_context
 
                 slot.n_decoded = 0;
                 slot.i_batch = batch.n_tokens - 1;
+                slot.state = PROCESSING;
+                slot.command = NONE;
             }
         }
     }
@@ -868,10 +881,18 @@ struct llama_server_context
         }
 
         for (auto & slot : slots) {
+
             if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                 continue;
             }
 
+            // prompt evaluated for embedding
+            if(params.embedding) {
+                slot.release();
+                slot.i_batch = -1;
+                return true;
+            }
+
             completion_token_output result;
             const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
@@ -1316,6 +1337,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
 
     json res = json{
         {"content", content},
+        {"slot_id", slot->id},
         {"stop", true},
         {"model", llama.params.model_alias},
         {"tokens_predicted", slot->num_tokens_predicted},
@@ -1327,7 +1349,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
         {"stopped_word", slot->stopped_word},
         {"stopped_limit", slot->stopped_limit},
         {"stopping_word", slot->stopping_word},
-        {"tokens_cached", slot->n_past},
+        {"tokens_cached", slot->n_past}
         // {"timings", format_timings(llama)},
     };
 
@@ -1383,6 +1405,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     llama_sampling_params default_sparams;
 
     slot->params.stream = json_value(body, "stream", false);
+    slot->params.cache_prompt = json_value(body, "cache_prompt", false);
     slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict);
     slot->sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
     slot->sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
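parse_options_completion now reads a per-request cache_prompt flag, so a client opts into prompt caching in the request body. A hedged example of building such a body with nlohmann::json (the json.hpp header the server example already ships); the prompt text and the claim that a follow-up request reuses the cached prefix are assumptions for illustration:

// Sketch only: a request body that opts into the new cache_prompt option,
// built with nlohmann::json. Endpoint behaviour for repeated requests is an
// assumption here, not something the diff guarantees.
#include <iostream>
#include "json.hpp"

using json = nlohmann::json;

int main() {
    const json req = {
        {"prompt",       "You are a helpful assistant.\nUser: hello\nAssistant:"},
        {"n_predict",    64},
        {"stream",       false},
        {"cache_prompt", true}   // keep the slot SLEEPING with its prompt tokens after this request
    };
    // Re-sending the same prompt (or an extension of it) should then only pay
    // for the tokens past the shared prefix, as computed by common_part().
    std::cout << req.dump(2) << std::endl;
}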
@ -1495,8 +1518,8 @@ static void log_server_request(const Request &req, const Response &res)
|
||||
});
|
||||
}
|
||||
|
||||
static bool is_at_eob(llama_server_context &server_context, const llama_token *tokens, const size_t n_tokens) {
|
||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
|
||||
static bool is_at_eob(llama_server_context * server_context, const llama_token *tokens, const size_t n_tokens) {
|
||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context->ctx);
|
||||
}
|
||||
|
||||
// Function matching type llama_beam_search_callback_fn_t.
|
||||
@@ -1509,21 +1532,21 @@ static bool is_at_eob(llama_server_context &server_context, const llama_token *t
 // AVOID HEADACHES unnecessaries
 
 // static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
-//     auto & llama = *static_cast<llama_server_context*>(callback_data);
+//     auto & llama = *static_cast<beam_search_callback_data*>(callback_data);
 //     // Mark beams as EOS as needed.
 //     for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
 //         llama_beam_view& beam_view = beams_state.beam_views[i];
-//         if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+//         if (!beam_view.eob && is_at_eob(llama.ctx, beam_view.tokens, beam_view.n_tokens)) {
 //             beam_view.eob = true;
 //         }
 //     }
 //     printf(","); // Show progress
 //     if (const size_t n = beams_state.common_prefix_length) {
-//         llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+//         llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
 //         assert(0u < beams_state.n_beams);
 //         const llama_token * tokens = beams_state.beam_views[0].tokens;
 //         const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
-//         std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+//         std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
 //         printf("%zu", n);
 //     }
 //     fflush(stdout);
@@ -1541,17 +1564,17 @@ struct token_translator {
     std::string operator()(const completion_token_output & cto) const { return (*this)(cto.tok); }
 };
 
-static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot & slot)
+static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot* slot)
 {
-    auto & gtps = slot.generated_token_probs;
+    auto & gtps = slot->generated_token_probs;
     auto translator = token_translator{llama.ctx};
     auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
     const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
-    if (slot.generated_text.capacity() < slot.generated_text.size() + len) {
-        slot.generated_text.reserve(slot.generated_text.size() + len);
+    if (slot->generated_text.capacity() < slot->generated_text.size() + len) {
+        slot->generated_text.reserve(slot->generated_text.size() + len);
     }
     for (const completion_token_output & cto : gtps) {
-        slot.generated_text += translator(cto);
+        slot->generated_text += translator(cto);
     }
 }
 
@@ -1662,17 +1685,19 @@ int main(int argc, char **argv)
             std::string completion_text = "";
             if (llama.params.n_beams) {
                 // // Fill llama.generated_token_probs vector with final beam.
-                // llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                //                   slot->n_past, llama.n_remain);
+                // beam_search_callback_data data_;
+                // data_.slot = slot;
+                // data_.ctx = &llama;
+                // llama_beam_search(llama.ctx, beam_search_callback, &data_, llama.params.n_beams,
+                //                   slot->n_past, llama.params.n_predict);
                 // // Translate llama.generated_token_probs to llama.generated_text.
-                // append_to_generated_text_from_generated_token_probs(llama);
+                // append_to_generated_text_from_generated_token_probs(llama, slot);
             } else {
-
                 while (slot->isProcessing()) {
                     if(slot->hasNewToken()) {
                         completion_text += slot->next().text_to_send;
                     } else {
-                        std::this_thread::sleep_for(std::chrono::milliseconds(5));
+                        std::this_thread::sleep_for(std::chrono::microseconds(5));
                    }
                }
            }
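The non-streaming path above, and the nearly identical block added further down for the second endpoint, busy-waits on isProcessing()/hasNewToken() with a 5-microsecond sleep. The sketch below pulls that polling pattern into one helper so the control flow is easier to see; the slot type here is a stand-in and the real code keeps the loop inline (a condition variable would avoid polling altogether):

// Sketch only: the polling pattern used by the non-streaming handlers above,
// pulled into a helper. slot_like is a stand-in for llama_client_slot; the
// real code keeps this loop inline in each handler.
#include <chrono>
#include <cstdio>
#include <string>
#include <thread>

template <typename slot_like>
static std::string collect_completion_text(slot_like * slot) {
    std::string completion_text;
    while (slot->isProcessing()) {
        if (slot->hasNewToken()) {
            completion_text += slot->next().text_to_send;
        } else {
            // brief sleep instead of a hard spin; the diff uses 5 microseconds
            std::this_thread::sleep_for(std::chrono::microseconds(5));
        }
    }
    return completion_text;
}

// Minimal mock so the sketch compiles and runs on its own.
struct mock_token { std::string text_to_send; };
struct mock_slot {
    int total = 3;
    int sent  = 0;
    bool isProcessing() const { return sent < total; }
    bool hasNewToken()  const { return sent < total; }
    mock_token next() { ++sent; return {"tok "}; }
};

int main() {
    mock_slot s;
    std::printf("%s\n", collect_completion_text(&s).c_str());  // prints "tok tok tok "
}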
@@ -1686,7 +1711,7 @@ int main(int argc, char **argv)
             const json data = format_final_response(llama, slot, completion_text, probs);
 
             //llama_print_timings(llama.ctx);
-
+            slot->release();
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
         } else {
@ -1743,8 +1768,7 @@ int main(int argc, char **argv)
|
||||
return true;
|
||||
};
|
||||
auto on_complete = [slot, &llama] (bool) {
|
||||
slot->sent_tokens = 0;
|
||||
slot->generated_token_probs.clear();
|
||||
slot->clean_tokens();
|
||||
slot->release();
|
||||
};
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
@@ -1781,6 +1805,26 @@ int main(int argc, char **argv)
                 return;
             }
 
+            if(!slot->params.stream) {
+                std::string completion_text = "";
+                while (slot->isProcessing()) {
+                    if(slot->hasNewToken()) {
+                        completion_text += slot->next().text_to_send;
+                    } else {
+                        std::this_thread::sleep_for(std::chrono::microseconds(5));
+                    }
+                }
+                auto probs = slot->generated_token_probs;
+                if (slot->sparams.n_probs > 0 && slot->stopped_word) {
+                    const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
+                    probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
+                }
+
+                const json data = format_final_response(llama, slot, completion_text, probs);
+                //llama_print_timings(llama.ctx);
+                res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                                "application/json");
+            } else {
                 const auto chunked_content_provider = [slot, &llama](size_t, DataSink & sink) {
                     size_t sent_token_probs_index = 0;
                     while(slot->isProcessing()) {
@@ -1834,11 +1878,11 @@ int main(int argc, char **argv)
                     return true;
                 };
                 auto on_complete = [slot, &llama] (bool) {
-                    slot->sent_tokens = 0;
-                    slot->generated_token_probs.clear();
+                    slot->clean_tokens();
                     slot->release();
                 };
                 res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            }
         });
 
     svr.Get("/model.json", [&llama](const Request &, Response &res)
@@ -1878,9 +1922,7 @@ int main(int argc, char **argv)
     svr.Post("/embedding", [&llama](const Request &req, Response &res)
             {
                 const json body = json::parse(req.body);
-
                 llama_client_slot* slot = llama.getSlot(-1);
-
                 slot->reset();
                 //llama_reset_timings(llama.ctx);
                 if (body.count("content") != 0)