fix long prompt than ctx proposed in #3639

This commit is contained in:
FSSRepo 2023-10-15 19:07:18 -04:00
parent b727e022d6
commit fd64f04fc2

View File

@ -395,7 +395,7 @@ struct llama_client_slot
} }
bool isProcessing() { bool isProcessing() {
return (state == IDLE || state == SLEEPING) && command == LOAD_PROMPT || state == PROCESSING; return ((state == IDLE || state == SLEEPING) && command == LOAD_PROMPT) || state == PROCESSING;
} }
completion_token_output next() { completion_token_output next() {
@ -1041,26 +1041,22 @@ struct llama_server_context
//if input prompt is too big, truncate like normal //if input prompt is too big, truncate like normal
if (slot.num_prompt_tokens >= (size_t)n_ctx) if (slot.num_prompt_tokens >= (size_t)n_ctx)
{ {
const int n_left = (n_ctx - params.n_keep) / 2; const int n_left = n_ctx - params.n_keep;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (slot.num_prompt_tokens - params.n_keep - n_left - 1) / n_left; // Use half the left-over space in the context for the prompt
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), slot.last_n_tokens.begin());
LOG_VERBOSE("input truncated", { LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx}, {"n_ctx", n_ctx},
{"n_keep", params.n_keep}, {"n_keep", params.n_keep},
{"n_left", n_left}, {"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
}); });
slot.truncated = true; slot.truncated = true;
prompt_tokens = new_tokens; prompt_tokens = new_tokens;
} else {
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
} }
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
} }
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1); llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
@ -1925,7 +1921,7 @@ static void beam_search_callback(void *callback_data, llama_beams_state beams_st
llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n); llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
assert(0u < beams_state.n_beams); assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens; const llama_token * tokens = beams_state.beam_views[0].tokens;
const auto map = [](llama_token tok) { return completion_token_output{{},tok}; }; const auto map = [](llama_token tok) { return completion_token_output{{},tok,""}; };
std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map); std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
printf("%zu", n); printf("%zu", n);
} }