mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 20:04:35 +00:00

add context swap

parent b6d9e212e5
commit a2c2d98c16
@@ -79,7 +79,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int32_t n_predict = 128; // new tokens to predict
+    int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
     std::vector<std::string> antiprompt;
@@ -224,6 +224,7 @@ struct llama_client_slot
     int32_t i_batch = -1;
     int32_t num_prompt_tokens = 0;
     int32_t num_prompt_tokens_processed = 0;
+    int32_t n_remaining = -1;

     json prompt;
     std::string generated_text = "";
@@ -308,6 +309,16 @@ struct llama_client_slot
         return true;
     }

+    bool hasBudget(gpt_params &global_params) {
+        n_remaining = -1;
+        if(params.n_predict != -1) {
+            n_remaining = params.n_predict - n_decoded;
+        } else if(global_params.n_predict != -1) {
+            n_remaining = global_params.n_predict - n_decoded;
+        }
+        return n_remaining > 0 || n_remaining == -1; // no budget || limitless
+    }
+
     bool hasNewToken() {
         return num_tokens_predicted > sent_tokens;
     }
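
Note: the new hasBudget() helper decides whether a slot may keep generating: the per-request n_predict takes priority over the global one, and -1 means "no limit". Below is a minimal standalone sketch of that logic; sketch_params and sketch_slot are illustrative stand-ins, not the server's real gpt_params / llama_client_slot types.

#include <cstdint>
#include <cstdio>

struct sketch_params { int32_t n_predict = -1; };   // -1 == limitless

struct sketch_slot {
    sketch_params params;       // per-request settings
    int32_t n_decoded   = 0;    // tokens generated so far
    int32_t n_remaining = -1;   // recomputed on every check

    bool hasBudget(const sketch_params &global_params) {
        n_remaining = -1;
        if (params.n_predict != -1) {
            n_remaining = params.n_predict - n_decoded;          // request limit wins
        } else if (global_params.n_predict != -1) {
            n_remaining = global_params.n_predict - n_decoded;   // fall back to global limit
        }
        return n_remaining > 0 || n_remaining == -1; // budget left || limitless
    }
};

int main() {
    sketch_params global;        // server-wide default (e.g. --n-predict)
    global.n_predict = 4;
    sketch_slot slot;            // this request did not set n_predict

    // generate until the budget runs out, as the server loop would
    while (slot.hasBudget(global)) {
        slot.n_decoded++;
    }
    printf("stopped after %d tokens, n_remaining = %d\n",
           slot.n_decoded, slot.n_remaining);
    return 0;
}
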
@@ -607,6 +618,8 @@ struct llama_server_context
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
+
+        // search stop word and delete it
         slot.generated_text += token_str;
         size_t pos = std::min(slot.sent_count, slot.generated_text.size());
         const std::string str_test = slot.generated_text.substr(pos);
@@ -623,17 +636,17 @@ struct llama_server_context
             stop_pos = findStoppingStrings(str_test, token_str.size(),
                 STOP_PARTIAL, slot);
         }
+
+        // check if there is any token to predict
         bool has_next_token = !is_stop_full && stop_pos > 0;
         if(stop_pos == std::string::npos) {
+            // no send the stop word in the response
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             has_next_token = true;
         }
+        // add the token to slot queue and cache
         slot.addTokenString(result);
-        if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
-            slot.n_past + slot.n_decoded >= params.n_predict)) {
-            has_next_token = false;
-        }
         if (slot.sparams.n_probs > 0)
         {
             slot.generated_token_probs.push_back(result);
@@ -671,20 +684,25 @@ struct llama_server_context
             has_next_token = true;
         }

-        if (!has_next_token && (slot.n_decoded + slot.n_past >= params.n_predict))
+        // check the limits
+        if (
+            slot.n_decoded > 2 && has_next_token && !slot.hasBudget(params))
         {
             slot.stopped_limit = true;
+            has_next_token = false;
         }
+
         if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
             slot.stopped_eos = true;
+            has_next_token = false;
             LOG_VERBOSE("eos token found", {});
         }

         LOG_VERBOSE("next token", {
             {"token", result.tok},
             {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
             {"has_next_token", has_next_token},
-            {"n_remain", (params.n_predict - slot.n_decoded + slot.n_past)},
+            {"n_remain", slot.n_remaining},
             {"num_tokens_predicted", slot.num_tokens_predicted},
             {"stopped_eos", slot.stopped_eos},
             {"stopped_word", slot.stopped_word},
@@ -736,12 +754,13 @@ struct llama_server_context
             }

             batch.token [batch.n_tokens] = slot.sampled;
-            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past + slot.n_decoded;
+            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past;
             batch.seq_id[batch.n_tokens] = slot.id;
             batch.logits[batch.n_tokens] = true;

             slot.n_decoded += 1;
             slot.i_batch = batch.n_tokens;
+            slot.n_past += 1;

             batch.n_tokens += 1;
         }
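
Note: with the hunk above, slot.n_past advances by one for every generated token, so the batch position is simply num_tokens_system + slot.n_past instead of being recomputed from n_past + n_decoded. A rough sketch of that accounting with plain integers follows; no llama_batch is involved and the loop is illustrative only.

#include <cstdio>

int main() {
    const int num_tokens_system = 3;  // tokens of the shared system prompt
    int n_past    = 10;               // tokens already in this slot's cache
    int n_decoded = 0;                // tokens generated so far

    for (int step = 0; step < 3; step++) {
        int pos = num_tokens_system + n_past;    // where the sampled token is placed
        printf("step %d -> batch.pos = %d\n", step, pos);
        n_decoded += 1;
        n_past    += 1;               // keeps pos in sync with the cache contents
    }
    return 0;
}
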
@@ -853,6 +872,37 @@ struct llama_server_context
             return true;
         }

+        // context shift
+        if(slots.size() == 1) {
+            llama_client_slot slot = slots[0];
+            if (slot.cache_tokens.size() >= (size_t)n_ctx)
+            {
+                // Shift context
+                const int n_left    = slot.n_past - params.n_keep - 1;
+                const int n_discard = n_left / 2;
+
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+
+                for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                {
+                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                }
+
+                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
+
+                LOG_VERBOSE("input truncated", {
+                    {"n_ctx", n_ctx},
+                    {"n_keep", params.n_keep},
+                    {"n_left", n_left},
+                });
+            }
+        }
+
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;

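
Note: this is the "context swap" the commit title refers to: once the slot's token cache reaches n_ctx, the first n_keep + 1 tokens are kept, half of the remainder is discarded, and the tail is slid back so generation can continue. Below is a standalone sketch of the same arithmetic on a plain token vector; the real code also shifts the KV cache via llama_kv_cache_seq_rm / llama_kv_cache_seq_shift, which this sketch deliberately leaves out.

#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 16;           // pretend context size
    const int n_keep = 2;            // tokens pinned at the start of the prompt
    std::vector<int> cache_tokens;   // stand-in for slot.cache_tokens
    for (int i = 0; i < n_ctx; i++) cache_tokens.push_back(i);
    int n_past = (int) cache_tokens.size();

    if ((int) cache_tokens.size() >= n_ctx) {
        const int n_left    = n_past - n_keep - 1;   // tokens eligible for discarding
        const int n_discard = n_left / 2;            // drop half of them

        // slide the kept tail over the discarded region
        for (size_t i = n_keep + 1 + n_discard; i < cache_tokens.size(); i++) {
            cache_tokens[i - n_discard] = cache_tokens[i];
        }
        cache_tokens.resize(cache_tokens.size() - n_discard);
        n_past -= n_discard;
    }

    printf("n_past=%d, tokens:", n_past);
    for (int t : cache_tokens) printf(" %d", t);
    printf("\n");
    return 0;
}
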
@@ -1264,9 +1314,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-            if(params.n_predict <= 128) { // this example don't support long prompts
-                params.n_predict = 128;
-            }
         }
         else
         {
@@ -1428,7 +1475,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    //llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
+    llama.params.n_keep = json_value(body, "n_keep", -1);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);