diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a130f5e09..d839d6ce5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -79,7 +79,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int32_t n_predict = 128; // new tokens to predict
+    int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
     std::vector<std::string> antiprompt;
@@ -224,6 +224,7 @@ struct llama_client_slot
     int32_t i_batch = -1;
     int32_t num_prompt_tokens = 0;
     int32_t num_prompt_tokens_processed = 0;
+    int32_t n_remaining = -1;

     json prompt;
     std::string generated_text = "";
@@ -308,6 +309,16 @@ struct llama_client_slot
         return true;
     }

+    bool hasBudget(gpt_params &global_params) {
+        n_remaining = -1;
+        if(params.n_predict != -1) {
+            n_remaining = params.n_predict - n_decoded;
+        } else if(global_params.n_predict != -1) {
+            n_remaining = global_params.n_predict - n_decoded;
+        }
+        return n_remaining > 0 || n_remaining == -1; // still has budget || no limit set
+    }
+
     bool hasNewToken() {
         return num_tokens_predicted > sent_tokens;
     }
@@ -607,6 +618,8 @@ struct llama_server_context
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
+
+        // search stop word and delete it
         slot.generated_text += token_str;
         size_t pos = std::min(slot.sent_count, slot.generated_text.size());
         const std::string str_test = slot.generated_text.substr(pos);
@@ -623,17 +636,17 @@ struct llama_server_context
             stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_PARTIAL, slot);
         }
+
+        // check if there is any token to predict
         bool has_next_token = !is_stop_full && stop_pos > 0;
         if(stop_pos == std::string::npos) {
+            // do not send the stop word in the response
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             has_next_token = true;
         }

+        // add the token to slot queue and cache
         slot.addTokenString(result);
-        if(slot.n_decoded > 2 && (result.tok == llama_token_eos(ctx) ||
-            slot.n_past + slot.n_decoded >= params.n_predict)) {
-            has_next_token = false;
-        }

         if (slot.sparams.n_probs > 0) {
             slot.generated_token_probs.push_back(result);
@@ -671,20 +684,25 @@ struct llama_server_context
             has_next_token = true;
         }

-        if (!has_next_token && (slot.n_decoded + slot.n_past >= params.n_predict))
+
+        // check the limits
+        if (slot.n_decoded > 2 && has_next_token && !slot.hasBudget(params))
         {
             slot.stopped_limit = true;
+            has_next_token = false;
         }

         if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
             slot.stopped_eos = true;
+            has_next_token = false;
             LOG_VERBOSE("eos token found", {});
         }
+
         LOG_VERBOSE("next token", {
             {"token", result.tok},
             {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
             {"has_next_token", has_next_token},
-            {"n_remain", (params.n_predict - slot.n_decoded + slot.n_past)},
+            {"n_remain", slot.n_remaining},
             {"num_tokens_predicted", slot.num_tokens_predicted},
             {"stopped_eos", slot.stopped_eos},
             {"stopped_word", slot.stopped_word},
@@ -736,12 +754,13 @@ struct llama_server_context
             }

             batch.token [batch.n_tokens] = slot.sampled;
-            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past + slot.n_decoded;
+            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past;
             batch.seq_id[batch.n_tokens] = slot.id;
             batch.logits[batch.n_tokens] = true;

             slot.n_decoded += 1;
             slot.i_batch = batch.n_tokens;
+            slot.n_past += 1;

             batch.n_tokens += 1;
         }
@@ -853,6 +872,37 @@ struct llama_server_context
             return true;
         }

+        // context shift
+        if(slots.size() == 1) {
+            llama_client_slot &slot = slots[0];
+            if (slot.cache_tokens.size() >= (size_t)n_ctx)
+            {
+                // Shift context
+                const int n_left    = slot.n_past - params.n_keep - 1;
+                const int n_discard = n_left / 2;
+
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+
+                for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                {
+                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                }
+
+                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
+
+                LOG_VERBOSE("input truncated", {
+                    {"n_ctx", n_ctx},
+                    {"n_keep", params.n_keep},
+                    {"n_left", n_left},
+                });
+            }
+        }
+
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;

@@ -1264,9 +1314,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-            if(params.n_predict <= 128) { // this example don't support long prompts
-                params.n_predict = 128;
-            }
         }
         else
         {
@@ -1428,7 +1475,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    //llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
+    llama.params.n_keep = json_value(body, "n_keep", -1);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
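
The per-slot budget check added above can be read in isolation: the request's own `n_predict` takes precedence, otherwise the global `--n-predict` applies, and `-1` on both sides means unlimited generation. Below is a minimal standalone sketch of that rule, assuming a hypothetical free function `has_budget` over plain ints rather than the real `llama_client_slot::hasBudget()` member:

```cpp
#include <cstdio>

// Sketch of the budget rule from the patch: the slot's own n_predict wins,
// then the global n_predict, and -1 on both sides means "no limit".
static bool has_budget(int slot_n_predict, int global_n_predict, int n_decoded, int &n_remaining) {
    n_remaining = -1;
    if (slot_n_predict != -1) {
        n_remaining = slot_n_predict - n_decoded;
    } else if (global_n_predict != -1) {
        n_remaining = global_n_predict - n_decoded;
    }
    return n_remaining > 0 || n_remaining == -1; // still has budget || no limit set
}

int main() {
    int n_remaining = 0;
    printf("%d\n", has_budget(32,  -1,  10, n_remaining)); // 1: 22 tokens left on the request budget
    printf("%d\n", has_budget(-1, 128, 128, n_remaining)); // 0: global budget spent (stopped_limit case)
    printf("%d\n", has_budget(-1,  -1, 999, n_remaining)); // 1: limitless
    return 0;
}
```

Returning true when `n_remaining == -1` is what lets a request keep generating when neither the request nor the server sets a limit; the `slot.n_decoded > 2` guard in the hunk additionally skips the limit check for the first couple of decoded tokens.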
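
The context-shift hunk keeps the first `params.n_keep + 1` cache entries, discards half of what remains, and compacts the survivors so the slot's `n_past` shrinks accordingly. Here is a minimal sketch of just that bookkeeping, assuming a hypothetical `shift_context` helper over a plain token vector (in the patch the same range is also removed and shifted inside the KV cache via `llama_kv_cache_seq_rm` / `llama_kv_cache_seq_shift`):

```cpp
#include <cstdio>
#include <vector>

// Keep the first n_keep + 1 tokens, drop half of the rest, and compact the
// survivors so generation can continue with a smaller n_past.
static int shift_context(std::vector<int> &cache_tokens, int &n_past, int n_keep) {
    const int n_left    = n_past - n_keep - 1;
    const int n_discard = n_left / 2;

    for (size_t i = n_keep + 1 + n_discard; i < cache_tokens.size(); i++) {
        cache_tokens[i - n_discard] = cache_tokens[i];
    }
    cache_tokens.resize(cache_tokens.size() - n_discard);
    n_past -= n_discard;
    return n_discard;
}

int main() {
    std::vector<int> cache(16);
    for (int i = 0; i < 16; i++) cache[i] = i; // pretend token ids 0..15 fill the context

    int n_past = 16;
    const int n_discard = shift_context(cache, n_past, /*n_keep=*/3);

    printf("discarded %d, n_past %d, cache:", n_discard, n_past);
    for (int t : cache) printf(" %d", t);
    printf("\n"); // discarded 6, n_past 10, cache: 0 1 2 3 10 11 12 13 14 15
    return 0;
}
```

With 16 cached tokens and `n_keep = 3` this drops 6 tokens from the middle of the window, matching the `n_left / 2` halving in the hunk; note that the patch only applies the shift when `slots.size() == 1`.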