server : use refs + use llama_batch_clear()

Georgi Gerganov 2023-10-19 14:44:04 +03:00
parent 3d5929e8ee
commit e3a2c3fe32


@@ -68,7 +68,8 @@ static const std::string base64_chars =
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
-static inline bool is_base64(uint8_t c) {
+static inline bool is_base64(uint8_t c)
+{
return (isalnum(c) || (c == '+') || (c == '/'));
}
@@ -174,7 +175,7 @@ struct slot_image
float* image_embedding = nullptr;
int image_tokens = 0;
int id;
std::string prefix_prompt = ""; // before of this image
std::string prefix_prompt; // before of this image
};
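Note on the removed = "" initializers throughout this commit: a default-constructed std::string is already empty, so the explicit empty-string initializer on members like prefix_prompt and generated_text is redundant. A minimal illustration (simplified struct, not the actual server code):

    #include <cassert>
    #include <string>

    struct slot_like {
        std::string prefix_prompt;        // default-constructed: already empty
        std::string generated_text = "";  // redundant initializer, same result
    };

    int main() {
        slot_like s;
        assert(s.prefix_prompt.empty() && s.generated_text.empty());
        return 0;
    }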
// completion token output with probabilities
@@ -350,7 +351,7 @@ struct llama_client_slot
int32_t n_remaining = -1;
json prompt;
std::string generated_text = "";
std::string generated_text;
int num_tokens_predicted = 0;
llama_token sampled;
std::vector<llama_token> cache_tokens;
@@ -404,7 +405,7 @@ struct llama_client_slot
ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
-for (slot_image img : images)
+for (slot_image &img : images)
{
free(img.image_embedding);
delete[] img.img_data.data;
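The "use refs" part of the commit: iterating with for (slot_image &img : images) instead of for (slot_image img : images) avoids copying each element (the struct carries a std::string and raw pointers) on every loop iteration. A small sketch of the pattern with a simplified struct, not the real slot_image:

    #include <cstdlib>
    #include <string>
    #include <vector>

    struct image_like {
        float      *embedding = nullptr;   // heap buffer owned by the struct
        std::string prefix_prompt;
    };

    void release_all(std::vector<image_like> &images) {
        // Taking the element by reference avoids a per-iteration copy of the
        // string and lets us reset the pointer on the stored element itself.
        for (image_like &img : images) {
            free(img.embedding);
            img.embedding = nullptr;
        }
    }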
@@ -489,14 +490,14 @@ struct llama_server_context
std::vector<llama_client_slot> slots;
// system prompt
std::string system_prompt = "";
std::string system_prompt;
bool need_update_system_prompt = false;
std::vector<llama_token> tokens_system;
int32_t num_tokens_system;
// broadcast to all clients to keep the same prompt format
std::string user_name = ""; // this should be the anti prompt
std::string assistant_name = ""; // this is for generate the prompt
std::string user_name; // this should be the anti prompt
std::string assistant_name; // this is for generate the prompt
bool multimodal = false;
clip_ctx *clp_ctx = nullptr;
@@ -870,7 +871,7 @@ struct llama_server_context
return slot.has_next_token; // continue
}
-bool processImages(llama_client_slot &slot)
+bool processImages(llama_client_slot &slot) const
{
for (slot_image &img : slot.images)
{
@@ -901,6 +902,7 @@ struct llama_server_context
}
img.request_encode_image = false;
}
return slot.images.size() > 0;
}
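On the const added to processImages(): a const member function only makes the implicit this pointer const; a slot passed by non-const reference can still be modified through it. A generic sketch of that rule, unrelated to the actual server internals:

    struct server_like {
        int n_ctx = 4096;

        // 'const' applies to the object the method is called on, not to the
        // arguments: the caller's counter can still be mutated here.
        bool process(int &n_past) const {
            n_past += n_ctx;   // OK: modifies the caller's variable
            // n_ctx = 0;      // would not compile: 'this' is const
            return true;
        }
    };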
@@ -908,9 +910,10 @@
bool ingest_images(llama_client_slot &slot, int n_batch)
{
int image_idx = 0;
while (image_idx < (int) slot.images.size())
{
-slot_image img = slot.images[image_idx];
+slot_image &img = slot.images[image_idx];
// process prefix prompt
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
@@ -942,7 +945,7 @@ struct llama_server_context
n_eval = n_batch;
}
-llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
if (llama_decode(ctx, batch_img))
{
LOG_TEE("%s : failed to eval image\n", __func__);
@@ -952,8 +955,9 @@ struct llama_server_context
}
image_idx++;
+llama_batch_clear(batch);
// append prefix of next image
-batch.n_tokens = 0;
const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
slot.params.input_suffix : // no more images, then process suffix prompt
(json)(slot.images[image_idx].prefix_prompt);
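llama_batch_clear() replaces the manual batch.n_tokens = 0 resets here and below. As far as I recall it is a tiny helper in llama.cpp's common code that only resets the token count so the batch's preallocated buffers can be refilled; roughly (approximate sketch, not copied from the repository):

    // Approximate shape of the helper; the real definition lives in the
    // common library (common/common.cpp at the time of this commit).
    #include "llama.h"

    void llama_batch_clear(struct llama_batch & batch) {
        batch.n_tokens = 0;   // keep the allocated arrays, just mark the batch empty
    }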
@@ -975,7 +979,8 @@ struct llama_server_context
update_system_prompt();
}
-batch.n_tokens = 0;
+llama_batch_clear(batch);
int kv_cache_free = (n_ctx - num_tokens_system);
if (all_slots_are_idle)