save dev progress

FSSRepo 2023-10-12 12:51:48 -04:00
parent 471230202d
commit 78504218b9
2 changed files with 59 additions and 43 deletions

.gitignore

@@ -10,6 +10,7 @@
 *.gcno
 *.gcda
 *.dot
+*.bat
 *.metallib
 .DS_Store
 .build/


@@ -382,7 +382,7 @@ struct llama_server_context
     gpt_params params;
     int n_ctx;
     int n_vocab;
-
+    bool clean_kv_cache = true;
     std::mutex mutex;

     std::unique_lock<std::mutex> lock()
@@ -484,6 +484,7 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
+            printf("----------------------\nprompt:\n%s-----------------------\n", s.c_str());
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }
@@ -622,17 +623,22 @@ struct llama_server_context
         // has_next_token = true;
     }

+    void cleanKVCache() {
+        // clear the entire KV cache
+        for (int i = 0; i < params.n_parallel; ++i)
+        {
+            llama_kv_cache_seq_rm(ctx, i, 0, -1);
+        }
+        clean_kv_cache = false;
+    }
+
     void updateSystemPrompt() {
         tokens_system = ::llama_tokenize(ctx, system_prompt, true);
         n_tokens_system = tokens_system.size();

         batch.n_tokens = n_tokens_system;

-        // clear the entire KV cache
-        for (int i = 0; i < params.n_parallel; ++i)
-        {
-            llama_kv_cache_seq_rm(ctx, i, 0, -1);
-        }
+        cleanKVCache();

         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
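As context for the hunk above: llama_kv_cache_seq_rm(ctx, seq_id, p0, p1) drops the cached tokens of one sequence in the range [p0, p1), and this branch passes -1 for p1 to mean "to the end of the sequence". The snippet below is only a sketch of the pattern the new cleanKVCache() helper follows, not part of the commit; clear_all_sequences and n_parallel are illustrative stand-ins for the member function and params.n_parallel.

    #include "llama.h"

    // Illustrative helper: wipe the cached tokens of every slot's sequence.
    // Each slot is assumed to own one KV sequence, indexed 0..n_parallel-1.
    static void clear_all_sequences(llama_context * ctx, int n_parallel) {
        for (int seq_id = 0; seq_id < n_parallel; ++seq_id) {
            // remove positions [0, end) for this sequence
            llama_kv_cache_seq_rm(ctx, seq_id, 0, -1);
        }
    }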
@@ -732,6 +738,7 @@ struct llama_server_context
         slot.last_n_tokens.erase(slot.last_n_tokens.begin());
         slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
+        printf("%s", token_str.c_str());
         slot.sampled = result.tok;

         size_t stop_pos =
@@ -819,6 +826,9 @@ struct llama_server_context
         int kv_cache_free = (n_ctx - n_tokens_system);

         if(all_slots_are_idle) {
+            if(system_prompt.empty() && clean_kv_cache) {
+                cleanKVCache();
+            }
             // avoid 100% usage of cpu all time
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
@@ -865,6 +875,7 @@ struct llama_server_context
             // need process the prompt
             bool keep_gen = slot.state == SLEEPING; // remember generation
             if ((slot.state == IDLE || keep_gen) && slot.command == LOAD_PROMPT) {
+                LOG_TEE("processing prompt\n");
                 slot.state = PROCESSING;
                 slot.command = NONE;

@@ -881,8 +892,12 @@ struct llama_server_context
                     {"to_eval", tokens_to_str(ctx, slot.context_tokens.cbegin() + slot.n_past, slot.context_tokens.cend())},
                 });

-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                if(system_prompt.empty()) {
+                    LOG_TEE("cleaning kv: %i\n", slot.n_past);
+                    llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+                }
+                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);

                 for (size_t i = slot.n_past; i < slot.context_tokens.size(); ++i) {
                     batch.token [batch.n_tokens] = slot.context_tokens[i];
                     batch.pos   [batch.n_tokens] = i + n_tokens_system;
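The llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1) call added above trims only the tail of the slot's sequence, so a prompt that shares a prefix with the previous request reuses the cached part and only the new tokens are re-evaluated. A minimal sketch of that idea follows; it is not part of the commit, and common_prefix and reuse_prefix are illustrative names.

    #include <vector>
    #include "llama.h"

    // Illustrative helper: length of the prefix shared by the tokens already
    // evaluated for a slot and the tokens of the incoming prompt.
    static size_t common_prefix(const std::vector<llama_token> & cached, const std::vector<llama_token> & prompt) {
        size_t n = 0;
        while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
            n++;
        }
        return n;
    }

    // Keep positions [0, n_past) for this slot's sequence and drop the rest,
    // mirroring the per-slot trim added in the hunk above.
    static void reuse_prefix(llama_context * ctx, llama_seq_id seq_id, size_t n_past) {
        llama_kv_cache_seq_rm(ctx, seq_id, (llama_pos) n_past, -1);
    }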
@@ -912,7 +927,6 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
@@ -1773,55 +1787,56 @@ int main(int argc, char **argv)
            //     res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
            //                     "application/json");
        } else {
-           auto chunked_content_provider = [&](size_t /*offset*/, DataSink &sink) {
+           printf("processing -> %s\n", slot->isProcessing() ? "true" : "false");
+           const auto chunked_content_provider = [slot](size_t, DataSink & sink) {
                size_t sent_count = 0;
                size_t sent_token_probs_index = 0;

                while(slot->isProcessing()) {
                    if(slot->hasNewToken()) { // new token notification
-                       const completion_token_output token = slot->next();
-                       std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
-                       std::vector<completion_token_output> probs_output = {};
-                       const json data = format_partial_response(llama, slot, token_str, probs_output);
-                       const std::string str =
-                           "data: " +
-                           data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                           "\n\n";
-                       LOG_VERBOSE("data stream", {
-                           { "to_send", str }
-                       });
-                       if(!sink.write(str.c_str(), str.size())) {
-                           slot->release();
-                           return false;
-                       }
+                       // const completion_token_output token = slot->next();
+                       // std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
+                       // std::vector<completion_token_output> probs_output = {};
+                       // const json data = format_partial_response(llama, slot, token_str, probs_output);
+                       // const std::string str =
+                       //     "data: " +
+                       //     data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                       //     "\n\n";
+                       // LOG_VERBOSE("data stream", {
+                       //     { "to_send", str }
+                       // });
+                       // if(!sink.write(str.c_str(), str.size())) {
+                       //     slot->release();
+                       //     return false;
+                       // }
                    } else {
                        std::this_thread::sleep_for(std::chrono::milliseconds(5));
                    }
                }

-               const json data = format_final_response(
-                   llama, slot,
-                   "",
-                   std::vector<completion_token_output>(
-                       slot->generated_token_probs.begin(),
-                       slot->generated_token_probs.begin() + sent_token_probs_index)
-               );
-               const std::string str =
-                   "data: " +
-                   data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                   "\n\n";
-               LOG_VERBOSE("data stream", {
-                   { "to_send", str }
-               });
-               if (!sink.write(str.data(), str.size())) {
-                   LOG_VERBOSE("stream closed", {});
-                   llama_print_timings(llama.ctx);
-                   return false;
-               }
+               // const json data = format_final_response(
+               //     llama, slot,
+               //     "",
+               //     std::vector<completion_token_output>(
+               //         slot->generated_token_probs.begin(),
+               //         slot->generated_token_probs.begin() + sent_token_probs_index)
+               // );
+               // const std::string str =
+               //     "data: " +
+               //     data.dump(-1, ' ', false, json::error_handler_t::replace) +
+               //     "\n\n";
+               // LOG_VERBOSE("data stream", {
+               //     { "to_send", str }
+               // });
+               // if (!sink.write(str.data(), str.size())) {
+               //     LOG_VERBOSE("stream closed", {});
+               //     llama_print_timings(llama.ctx);
+               //     return false;
+               // }
                sink.done();
                return true;
            };
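For reference, the streaming path above is built on cpp-httplib's chunked content provider: the handler installs a lambda that receives a DataSink, sink.write() returns false once the client disconnects, and sink.done() terminates the chunked response. Below is a self-contained sketch of that shape, not part of the commit: it streams a few fixed tokens as server-sent events, and the endpoint name and token list are made up, whereas the real handler pulls tokens from a llama_client_slot.

    #include <string>
    #include <vector>
    #include "httplib.h"

    int main() {
        httplib::Server svr;

        svr.Get("/stream-demo", [](const httplib::Request &, httplib::Response & res) {
            res.set_chunked_content_provider(
                "text/event-stream",
                [](size_t /*offset*/, httplib::DataSink & sink) {
                    const std::vector<std::string> tokens = {"Hello", ",", " world", "!"};
                    for (const auto & t : tokens) {
                        // SSE framing: each event is "data: <payload>\n\n"
                        const std::string chunk = "data: " + t + "\n\n";
                        if (!sink.write(chunk.c_str(), chunk.size())) {
                            return false; // client disconnected, stop streaming
                        }
                    }
                    sink.done();  // terminate the chunked response
                    return true;
                });
        });

        svr.listen("127.0.0.1", 8080);
        return 0;
    }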