mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
server : smart slot selection using Longest Common Prefix (#7728)
* server : Smart selection of available slot using Longest Common Substring * add usage * remove trailing whitespaces * Use Longest Common Prefix (LCP) instead of LCS * Rename argument
This commit is contained in:
parent
da799b4189
commit
7a16ce7db2
@ -1491,6 +1491,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.chat_template = argv[i];
|
params.chat_template = argv[i];
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
params.slot_prompt_similarity = std::stof(argv[i]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-pps") {
|
if (arg == "-pps") {
|
||||||
params.is_pp_shared = true;
|
params.is_pp_shared = true;
|
||||||
return true;
|
return true;
|
||||||
@ -1913,6 +1921,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
||||||
"only commonly used templates are accepted:\n"
|
"only commonly used templates are accepted:\n"
|
||||||
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
||||||
|
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
|
||||||
|
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
options.push_back({ "logging" });
|
options.push_back({ "logging" });
|
||||||
|
@ -203,6 +203,8 @@ struct gpt_params {
|
|||||||
|
|
||||||
std::string slot_save_path;
|
std::string slot_save_path;
|
||||||
|
|
||||||
|
float slot_prompt_similarity = 0.5f;
|
||||||
|
|
||||||
// batched-bench params
|
// batched-bench params
|
||||||
bool is_pp_shared = false;
|
bool is_pp_shared = false;
|
||||||
|
|
||||||
|
@ -647,6 +647,9 @@ struct server_context {
|
|||||||
|
|
||||||
server_metrics metrics;
|
server_metrics metrics;
|
||||||
|
|
||||||
|
// Necessary similarity of prompt for slot selection
|
||||||
|
float slot_prompt_similarity = 0.0f;
|
||||||
|
|
||||||
~server_context() {
|
~server_context() {
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
@ -795,24 +798,88 @@ struct server_context {
|
|||||||
return prompt_tokens;
|
return prompt_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
server_slot * get_slot(int id) {
|
server_slot * get_slot_by_id(int id) {
|
||||||
int64_t t_last = ggml_time_us();
|
|
||||||
|
|
||||||
server_slot * last_used = nullptr;
|
|
||||||
|
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
if (slot.id == id && slot.available()) {
|
if (slot.id == id) {
|
||||||
return &slot;
|
return &slot;
|
||||||
}
|
}
|
||||||
|
|
||||||
// among all available slots, find the one that has been least recently used
|
|
||||||
if (slot.available() && slot.t_last_used < t_last) {
|
|
||||||
last_used = &slot;
|
|
||||||
t_last = slot.t_last_used;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return last_used;
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
server_slot * get_available_slot(const std::string & prompt) {
|
||||||
|
server_slot * ret = nullptr;
|
||||||
|
|
||||||
|
// find the slot that has at least n% prompt similarity
|
||||||
|
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
|
||||||
|
int max_lcp_len = 0;
|
||||||
|
float similarity = 0;
|
||||||
|
|
||||||
|
for (server_slot & slot : slots) {
|
||||||
|
// skip the slot if it is not available
|
||||||
|
if (!slot.available()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip the slot if it does not contains prompt
|
||||||
|
if (!slot.prompt.is_string()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// current slot's prompt
|
||||||
|
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||||
|
|
||||||
|
// length of the current slot's prompt
|
||||||
|
int slot_prompt_len = slot_prompt.size();
|
||||||
|
|
||||||
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||||
|
int lcp_len = common_part(slot_prompt, prompt);
|
||||||
|
|
||||||
|
// fraction of the common substring length compared to the current slot's prompt length
|
||||||
|
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||||
|
|
||||||
|
// select the current slot if the criteria match
|
||||||
|
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||||
|
max_lcp_len = lcp_len;
|
||||||
|
ret = &slot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret != nullptr) {
|
||||||
|
LOG_VERBOSE("selected slot by lcp similarity", {
|
||||||
|
{"id_slot", ret->id},
|
||||||
|
{"max_lcp_len", max_lcp_len},
|
||||||
|
{"similarity", similarity},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the slot that has been least recently used
|
||||||
|
if (ret == nullptr) {
|
||||||
|
int64_t t_last = ggml_time_us();
|
||||||
|
for (server_slot & slot : slots) {
|
||||||
|
// skip the slot if it is not available
|
||||||
|
if (!slot.available()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the current slot if the criteria match
|
||||||
|
if (slot.t_last_used < t_last) {
|
||||||
|
t_last = slot.t_last_used;
|
||||||
|
ret = &slot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret != nullptr) {
|
||||||
|
LOG_VERBOSE("selected slot by lru", {
|
||||||
|
{"id_slot", ret->id},
|
||||||
|
{"t_last", t_last},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
||||||
@ -1515,13 +1582,29 @@ struct server_context {
|
|||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_COMPLETION:
|
||||||
{
|
{
|
||||||
server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
|
int id_slot = json_value(task.data, "id_slot", -1);
|
||||||
|
std::string prompt = json_value(task.data, "prompt", std::string());
|
||||||
|
|
||||||
|
server_slot * slot;
|
||||||
|
|
||||||
|
if (id_slot != -1) {
|
||||||
|
slot = get_slot_by_id(id_slot);
|
||||||
|
} else {
|
||||||
|
slot = get_available_slot(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
|
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(task);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (!slot->available()) {
|
||||||
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||||
|
queue_tasks.defer(task);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (task.data.contains("system_prompt")) {
|
if (task.data.contains("system_prompt")) {
|
||||||
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
||||||
@ -1638,11 +1721,17 @@ struct server_context {
|
|||||||
case SERVER_TASK_TYPE_SLOT_SAVE:
|
case SERVER_TASK_TYPE_SLOT_SAVE:
|
||||||
{
|
{
|
||||||
int id_slot = task.data.at("id_slot");
|
int id_slot = task.data.at("id_slot");
|
||||||
server_slot * slot = get_slot(id_slot);
|
server_slot * slot = get_slot_by_id(id_slot);
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (!slot->available()) {
|
||||||
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||||
|
queue_tasks.defer(task);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
const size_t token_count = slot->cache_tokens.size();
|
const size_t token_count = slot->cache_tokens.size();
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
@ -1673,11 +1762,17 @@ struct server_context {
|
|||||||
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
||||||
{
|
{
|
||||||
int id_slot = task.data.at("id_slot");
|
int id_slot = task.data.at("id_slot");
|
||||||
server_slot * slot = get_slot(id_slot);
|
server_slot * slot = get_slot_by_id(id_slot);
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (!slot->available()) {
|
||||||
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||||
|
queue_tasks.defer(task);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
@ -1715,11 +1810,17 @@ struct server_context {
|
|||||||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||||
{
|
{
|
||||||
int id_slot = task.data.at("id_slot");
|
int id_slot = task.data.at("id_slot");
|
||||||
server_slot * slot = get_slot(id_slot);
|
server_slot * slot = get_slot_by_id(id_slot);
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (!slot->available()) {
|
||||||
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
||||||
|
queue_tasks.defer(task);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Erase token cache
|
// Erase token cache
|
||||||
const size_t n_erased = slot->cache_tokens.size();
|
const size_t n_erased = slot->cache_tokens.size();
|
||||||
@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) {
|
|||||||
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
|
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Necessary similarity of prompt for slot selection
|
||||||
|
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
if (!ctx_server.load_model(params)) {
|
if (!ctx_server.load_model(params)) {
|
||||||
state.store(SERVER_STATE_ERROR);
|
state.store(SERVER_STATE_ERROR);
|
||||||
|
@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
|
|||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t common_part(const std::string & a, const std::string & b) {
|
||||||
|
size_t i;
|
||||||
|
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||||
|
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user