From fd64f04fc25e7d6b269e813843bfc47b26ce984d Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 15 Oct 2023 19:07:18 -0400
Subject: [PATCH] fix long prompt than ctx proposed in #3639

---
 examples/server/server.cpp | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d2abee864..66632a4ed 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -395,7 +395,7 @@ struct llama_client_slot
     }
 
     bool isProcessing() {
-        return (state == IDLE || state == SLEEPING) && command == LOAD_PROMPT || state == PROCESSING;
+        return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING;
     }
 
     completion_token_output next() {
@@ -1041,26 +1041,22 @@ struct llama_server_context
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx)
                 {
-                    const int n_left = (n_ctx - params.n_keep) / 2;
+                    const int n_left = n_ctx - params.n_keep;
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-                    const int erased_blocks = (slot.num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-                    std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), slot.last_n_tokens.begin());
-
+                    // Use half the left-over space in the context for the prompt
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
                         {"n_ctx", n_ctx},
                         {"n_keep", params.n_keep},
                         {"n_left", n_left},
                         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                     });
-
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
-                } else {
-                    const size_t ps = slot.num_prompt_tokens;
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                    std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
                 }
+                const size_t ps = slot.num_prompt_tokens;
+                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
+                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
             }
 
             llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
@@ -1925,7 +1921,7 @@ static void beam_search_callback(void *callback_data, llama_beams_state beams_st
     llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
     assert(0u < beams_state.n_beams);
     const llama_token * tokens = beams_state.beam_views[0].tokens;
-    const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+    const auto map = [](llama_token tok) { return completion_token_output{{},tok,""}; };
     std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
     printf("%zu", n);
 }
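
The second hunk replaces the block-erasure truncation with a simpler rule: keep the first params.n_keep prompt tokens, then fill half of the remaining context space (n_left / 2) with the most recent prompt tokens. What follows is a minimal standalone sketch of that rule, not part of the patch itself; the values of n_ctx and n_keep and the token ids are made up for illustration, and plain int stands in for llama_token.

// Standalone illustration of the truncation rule introduced by this patch:
// keep the first n_keep tokens, then use half the left-over context space
// (n_left / 2) for the tail of the prompt. All concrete values are assumed.
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 8;   // assumed context size
    const int n_keep = 2;   // assumed number of tokens always kept from the start
    std::vector<int> prompt_tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};

    if ((int) prompt_tokens.size() >= n_ctx) {
        const int n_left = n_ctx - n_keep;  // space left once the kept prefix is accounted for

        // keep the first n_keep tokens ...
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        // ... and append the last n_left / 2 tokens of the original prompt
        new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());

        prompt_tokens = new_tokens;
    }

    for (int t : prompt_tokens) {
        printf("%d ", t);   // prints: 0 1 9 10 11
    }
    printf("\n");
    return 0;
}

With these assumed values the truncated prompt is n_keep + n_left / 2 = 5 tokens, so the remaining half of the free context stays available for generated tokens instead of being consumed by the prompt.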