diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f0b89b22c..1c7f0fd1d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -378,8 +378,8 @@ struct server_queue { std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue int post(server_task task, bool front = false) { @@ -431,7 +431,7 @@ struct server_queue { } // Register function to process a new task - void on_new_task(std::function callback) { + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } @@ -481,7 +481,7 @@ struct server_queue { lock.unlock(); QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(task); + callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready @@ -644,17 +644,12 @@ struct server_context { bool load_model(const common_params & params_) { params = params_; - // reserve one extra sequence (seq_id == 0) for extra features - params.n_parallel += 1; - common_init_result llama_init = common_init_from_params(params); model = llama_init.model; ctx = llama_init.context; loras = llama_init.lora_adapters; - params.n_parallel -= 1; // but be sneaky about it - if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params.model.c_str()); return false; @@ -1288,16 +1283,16 @@ struct server_context { void send_embedding(const server_slot & slot, const llama_batch & batch) { server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + res.id = slot.id_task; + res.error = false; + res.stop = true; const int n_embd = llama_n_embd(model); std::vector embd_res(n_embd, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -1332,12 +1327,12 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + res.id = slot.id_task; + res.error = false; + res.stop = true; for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -1510,7 +1505,7 @@ struct server_context { // Functions to process the task // - void process_single_task(const server_task & task) { + void process_single_task(server_task task) { switch (task.type) { case SERVER_TASK_TYPE_INFERENCE: { @@ -1646,7 +1641,7 @@ struct server_context { std::string filename = task.data.at("filename"); std::string filepath = task.data.at("filepath"); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -1688,7 +1683,7 @@ struct server_context { slot->cache_tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); if (nread == 0) { slot->cache_tokens.resize(0); send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -1731,7 +1726,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -1808,8 +1803,8 @@ struct server_context { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -1836,7 +1831,7 @@ struct server_context { slot.i_batch = batch.n_tokens; - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true); + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -1983,8 +1978,8 @@ struct server_context { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c); - llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift); + llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); + llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); for (size_t i = 0; i < n_match; i++) { slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; @@ -2033,9 +2028,9 @@ struct server_context { } // keep only the common part - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) { + if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); // there is no common part left slot.n_past = 0; @@ -2048,7 +2043,7 @@ struct server_context { // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false); + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);