save dev progress

FSSRepo 2023-10-12 12:51:48 -04:00
parent 471230202d
commit 78504218b9
2 changed files with 59 additions and 43 deletions

.gitignore

@@ -10,6 +10,7 @@
*.gcno
*.gcda
*.dot
*.bat
*.metallib
.DS_Store
.build/

@@ -382,7 +382,7 @@ struct llama_server_context
gpt_params params;
int n_ctx;
int n_vocab;
bool clean_kv_cache = true;
std::mutex mutex;
std::unique_lock<std::mutex> lock()
@@ -484,6 +484,7 @@ struct llama_server_context
else
{
auto s = json_prompt.template get<std::string>();
printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
}
@@ -622,17 +623,22 @@ struct llama_server_context
// has_next_token = true;
}
void cleanKVCache() {
// clear the entire KV cache
for (int i = 0; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_rm(ctx, i, 0, -1);
}
clean_kv_cache = false;
}
void updateSystemPrompt() {
tokens_system = ::llama_tokenize(ctx, system_prompt, true);
n_tokens_system = tokens_system.size();
batch.n_tokens = n_tokens_system;
// clear the entire KV cache
for (int i = 0; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_rm(ctx, i, 0, -1);
}
cleanKVCache();
for (int32_t i = 0; i < batch.n_tokens; ++i)
{
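The cleanKVCache() helper introduced above drops every cached token of each parallel sequence so the system prompt can be re-evaluated from a clean state. As a minimal standalone sketch (the wrapper function name is hypothetical; llama_kv_cache_seq_rm and the per-sequence loop come straight from the diff, and a valid llama_context plus the n_parallel count are assumed to exist already):

#include "llama.h"

// Illustrative helper: removes all cached positions [0, end) for every
// parallel sequence, mirroring cleanKVCache() in the hunk above.
static void clear_all_sequences(llama_context * ctx, int n_parallel) {
    for (int i = 0; i < n_parallel; ++i) {
        llama_kv_cache_seq_rm(ctx, i, 0, -1); // -1 means "up to the end"
    }
}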
@@ -732,6 +738,7 @@ struct llama_server_context
slot.last_n_tokens.erase(slot.last_n_tokens.begin());
slot.last_n_tokens.push_back(result.tok);
const std::string token_str = llama_token_to_piece(ctx, result.tok);
printf("%s", token_str.c_str());
slot.sampled = result.tok;
size_t stop_pos =
@@ -819,6 +826,9 @@ struct llama_server_context
int kv_cache_free = (n_ctx - n_tokens_system);
if(all_slots_are_idle) {
if(system_prompt.empty() && clean_kv_cache) {
cleanKVCache();
}
// avoid 100% usage of cpu all time
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
@@ -865,6 +875,7 @@ struct llama_server_context
// need process the prompt
bool keep_gen = slot.state == SLEEPING; // remember generation
if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
LOG_TEE("processing prompt\n");
slot.state = PROCESSING;
slot.command = NONE;
@@ -881,8 +892,12 @@ struct llama_server_context
{"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
});
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
if(system_prompt.empty()) {
LOG_TEE("cleaning kv: %i\n", slot.n_past);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
}
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
for (size_t i = slot.n_past; i < slot.context_tokens.size(); ++i) {
batch.token [batch.n_tokens] = slot.context_tokens[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
@@ -912,7 +927,6 @@ struct llama_server_context
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
@@ -1773,55 +1787,56 @@ int main(int argc, char **argv)
// res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
// "application/json");
} else {
auto chunked_content_provider = [&](size_t /*offset*/, DataSink &sink) {
printf("processing -> %s\n", slot->isProcessing() ? "true" : "false");
const auto chunked_content_provider = [slot](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while(slot->isProcessing()) {
if(slot->hasNewToken()) { // new token notification
const completion_token_output token = slot->next();
std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
// const completion_token_output token = slot->next();
// std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
std::vector<completion_token_output> probs_output = {};
// std::vector<completion_token_output> probs_output = {};
const json data = format_partial_response(llama, slot, token_str, probs_output);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
// const json data = format_partial_response(llama, slot, token_str, probs_output);
// const std::string str =
// "data: " +
// data.dump(-1, ' ', false, json::error_handler_t::replace) +
// "\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if(!sink.write(str.c_str(), str.size())) {
slot->release();
return false;
}
// LOG_VERBOSE("data stream", {
// { "to_send", str }
// });
// if(!sink.write(str.c_str(), str.size())) {
// slot->release();
// return false;
// }
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
}
const json data = format_final_response(
llama, slot,
"",
std::vector<completion_token_output>(
slot->generated_token_probs.begin(),
slot->generated_token_probs.begin() + sent_token_probs_index)
);
// const json data = format_final_response(
// llama, slot,
// "",
// std::vector<completion_token_output>(
// slot->generated_token_probs.begin(),
// slot->generated_token_probs.begin() + sent_token_probs_index)
// );
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
// const std::string str =
// "data: " +
// data.dump(-1, ' ', false, json::error_handler_t::replace) +
// "\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
// LOG_VERBOSE("data stream", {
// { "to_send", str }
// });
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
}
// if (!sink.write(str.data(), str.size())) {
// LOG_VERBOSE("stream closed", {});
// llama_print_timings(llama.ctx);
// return false;
// }
sink.done();
return true;
};
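The hunk above comments out both the per-token streaming writes and the final response write, leaving only sink.done() while the parallel-slot path is being reworked. For context, the disabled path follows the usual server-sent-events pattern with cpp-httplib: a chunked content provider pulls pieces from the slot, wraps each one in a "data: ...\n\n" frame, writes it to the DataSink, and calls sink.done() when generation ends. A rough sketch of that pattern, with a stand-in token list instead of a real slot and assuming cpp-httplib's set_chunked_content_provider API rather than anything taken from this diff:

#include "httplib.h"

#include <memory>
#include <string>
#include <vector>

int main() {
    httplib::Server svr;

    svr.Get("/stream", [](const httplib::Request &, httplib::Response & res) {
        // stand-in for the tokens a generation slot would produce
        auto pieces = std::make_shared<std::vector<std::string>>(
            std::vector<std::string>{"Hello", ",", " world", "!"});

        res.set_chunked_content_provider(
            "text/event-stream",
            [pieces](size_t /*offset*/, httplib::DataSink & sink) {
                for (const auto & piece : *pieces) {
                    const std::string frame = "data: " + piece + "\n\n";
                    if (!sink.write(frame.c_str(), frame.size())) {
                        return false; // client closed the connection
                    }
                }
                sink.done(); // end of the chunked stream
                return true;
            });
    });

    svr.listen("127.0.0.1", 8080);
    return 0;
}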