examplse : de-shadow

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-12 14:25:32 +02:00
parent 82caffa74e
commit 9a735ae6d8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
16 changed files with 152 additions and 159 deletions

View File

@ -17,19 +17,19 @@
using json = nlohmann::ordered_json;
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
this->examples = std::move(examples);
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> vals) {
examples = std::move(vals);
return *this;
}
common_arg & common_arg::set_excludes(std::initializer_list<enum llama_example> excludes) {
this->excludes = std::move(excludes);
common_arg & common_arg::set_excludes(std::initializer_list<enum llama_example> vals) {
excludes = std::move(vals);
return *this;
}
common_arg & common_arg::set_env(const char * env) {
help = help + "\n(env: " + env + ")";
this->env = env;
common_arg & common_arg::set_env(const char * val) {
help = help + "\n(env: " + val + ")";
env = val;
return *this;
}
@ -46,8 +46,10 @@ bool common_arg::is_exclude(enum llama_example ex) {
return excludes.find(ex) != excludes.end();
}
bool common_arg::get_value_from_env(std::string & output) {
if (env == nullptr) return false;
bool common_arg::get_value_from_env(std::string & output) const {
if (env == nullptr) {
return false;
}
char * value = std::getenv(env);
if (value) {
output = value;
@ -56,7 +58,7 @@ bool common_arg::get_value_from_env(std::string & output) {
return false;
}
bool common_arg::has_value_from_env() {
bool common_arg::has_value_from_env() const {
return env != nullptr && std::getenv(env);
}
@ -87,7 +89,7 @@ static std::vector<std::string> break_str_into_lines(std::string input, size_t m
return result;
}
std::string common_arg::to_string() {
std::string common_arg::to_string() const {
// params for printing to console
const static int n_leading_spaces = 40;
const static int n_char_per_line_help = 70; // TODO: detect this based on current console
@ -192,8 +194,6 @@ static std::string get_all_kv_cache_types() {
//
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
std::string arg;
const std::string arg_prefix = "--";
common_params & params = ctx_arg.params;
std::unordered_map<std::string, common_arg *> arg_to_options;

View File

@ -53,15 +53,15 @@ struct common_arg {
void (*handler)(common_params & params, const std::string &, const std::string &)
) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
common_arg & set_env(const char * env);
common_arg & set_examples(std::initializer_list<enum llama_example> vals);
common_arg & set_excludes(std::initializer_list<enum llama_example> vals);
common_arg & set_env(const char * val);
common_arg & set_sparam();
bool in_example(enum llama_example ex);
bool is_exclude(enum llama_example ex);
bool get_value_from_env(std::string & output);
bool has_value_from_env();
std::string to_string();
bool get_value_from_env(std::string & output) const;
bool has_value_from_env() const;
std::string to_string() const;
};
struct common_params_context {

View File

@ -763,10 +763,12 @@ bool fs_create_directory_with_parents(const std::string & path) {
return true;
#else
// if the path already exists, check whether it's a directory
{
struct stat info;
if (stat(path.c_str(), &info) == 0) {
return S_ISDIR(info.st_mode);
}
}
size_t pos_slash = 1; // skip leading slashes for directory creation
@ -796,7 +798,7 @@ bool fs_create_directory_with_parents(const std::string & path) {
}
std::string fs_get_cache_directory() {
std::string cache_directory = "";
std::string cache_directory;
auto ensure_trailing_slash = [](std::string p) {
// Make sure to add trailing slash
if (p.back() != DIRECTORY_SEPARATOR) {

View File

@ -43,7 +43,7 @@ namespace console {
static bool simple_io = true;
static display_t current_display = reset;
static FILE* out = stdout;
static FILE* fout = stdout;
#if defined (_WIN32)
static void* hConsole;
@ -110,7 +110,7 @@ namespace console {
tty = fopen("/dev/tty", "w+");
if (tty != nullptr) {
out = tty;
fout = tty;
}
}
@ -126,7 +126,7 @@ namespace console {
// Restore settings on POSIX systems
if (!simple_io) {
if (tty != nullptr) {
out = stdout;
fout = stdout;
fclose(tty);
tty = nullptr;
}
@ -145,19 +145,19 @@ namespace console {
fflush(stdout);
switch(display) {
case reset:
fprintf(out, ANSI_COLOR_RESET);
fprintf(fout, ANSI_COLOR_RESET);
break;
case prompt:
fprintf(out, ANSI_COLOR_YELLOW);
fprintf(fout, ANSI_COLOR_YELLOW);
break;
case user_input:
fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
fprintf(fout, ANSI_BOLD ANSI_COLOR_GREEN);
break;
case error:
fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
fprintf(fout, ANSI_BOLD ANSI_COLOR_RED);
}
current_display = display;
fflush(out);
fflush(fout);
}
}
@ -233,7 +233,7 @@ namespace console {
return;
}
#endif
putc('\b', out);
putc('\b', fout);
}
static int estimateWidth(char32_t codepoint) {
@ -274,7 +274,7 @@ namespace console {
#else
// We can trust expectedWidth if we've got one
if (expectedWidth >= 0 || tty == nullptr) {
fwrite(utf8_codepoint, length, 1, out);
fwrite(utf8_codepoint, length, 1, fout);
return expectedWidth;
}
@ -311,7 +311,7 @@ namespace console {
pop_cursor();
put_codepoint(&ch, 1, 1);
#else
fprintf(out, "\b%c", ch);
fprintf(fout, "\b%c", ch);
#endif
}
@ -353,7 +353,7 @@ namespace console {
}
static bool readline_advanced(std::string & line, bool multiline_input) {
if (out != stdout) {
if (fout != stdout) {
fflush(stdout);
}
@ -364,7 +364,7 @@ namespace console {
char32_t input_char;
while (true) {
fflush(out); // Ensure all output is displayed before waiting for input
fflush(fout); // Ensure all output is displayed before waiting for input
input_char = getchar32();
if (input_char == '\r' || input_char == '\n') {
@ -432,7 +432,7 @@ namespace console {
line.pop_back();
if (last == '\\') {
line += '\n';
fputc('\n', out);
fputc('\n', fout);
has_more = !has_more;
} else {
// llama will just eat the single space, it won't act as a space
@ -447,11 +447,11 @@ namespace console {
has_more = false;
} else {
line += '\n';
fputc('\n', out);
fputc('\n', fout);
}
}
fflush(out);
fflush(fout);
return has_more;
}

View File

@ -338,16 +338,16 @@ public:
resume();
}
void set_prefix(bool prefix) {
void set_prefix(bool val) {
std::lock_guard<std::mutex> lock(mtx);
this->prefix = prefix;
prefix = val;
}
void set_timestamps(bool timestamps) {
void set_timestamps(bool val) {
std::lock_guard<std::mutex> lock(mtx);
this->timestamps = timestamps;
timestamps = val;
}
};

View File

@ -471,12 +471,12 @@ struct my_llama_file {
GGML_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
void read_raw(void * ptr, size_t size_cur) {
if (size_cur == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
std::size_t ret = std::fread(ptr, size_cur, 1, fp);
if (ferror(fp)) {
die_fmt("fread failed: %s", strerror(errno));
}

View File

@ -60,13 +60,6 @@ int main(int argc, char** argv) {
const std::string grammar_filename = argv[1];
const std::string input_filename = argv[2];
// Read the GBNF grammar file
FILE* grammar_file = fopen(grammar_filename.c_str(), "r");
if (!grammar_file) {
fprintf(stdout, "Failed to open grammar file: %s\n", grammar_filename.c_str());
return 1;
}
std::string grammar_str;
{
std::ifstream grammar_file(grammar_filename);

View File

@ -294,7 +294,7 @@ void IMatrixCollector::save_imatrix(int ncall) const {
bool IMatrixCollector::load_imatrix(const char * fname) {
std::ifstream in(fname, std::ios::binary);
if (!in) {
LOG_ERR("%s: failed to open %s\n",__func__, fname);
LOG_ERR("%s: failed to open %s\n", __func__, fname);
return false;
}
int n_entries;
@ -308,7 +308,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
std::vector<char> name_as_vec(len+1);
in.read((char *)name_as_vec.data(), len);
if (in.fail()) {
LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
LOG_ERR("%s: failed reading name for entry %d from %s\n", __func__, i + 1, fname);
return false;
}
name_as_vec[len] = 0;
@ -319,7 +319,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
int nval;
in.read((char *)&nval, sizeof(nval));
if (in.fail() || nval < 1) {
LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i);
LOG_ERR("%s: failed reading number of values for entry %d\n", __func__, i);
m_stats = {};
return false;
}
@ -332,15 +332,15 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
std::vector<float> tmp(nval);
in.read((char*)tmp.data(), nval*sizeof(float));
if (in.fail()) {
LOG_ERR("%s: failed reading data for entry %d\n",__func__,i);
LOG_ERR("%s: failed reading data for entry %d\n", __func__, i);
m_stats = {};
return false;
}
// Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
for (int i = 0; i < nval; i++) {
e.values[i] += tmp[i];
e.counts[i] += ncall;
for (int j = 0; j < nval; j++) {
e.values[j] += tmp[j];
e.counts[j] += ncall;
}
e.ncall += ncall;
@ -488,12 +488,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
logits.reserve((size_t)n_ctx * n_vocab);
}
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
for (int ich = 0; ich < n_chunk; ++ich) {
const int start = ich * n_ctx;
const int end = start + n_ctx;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
@ -537,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
if (ich == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@ -555,7 +553,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
LOG("[%d]%.4lf,", ich + 1, std::exp(nll / count));
fflush(stdout);
logits.clear();

View File

@ -462,14 +462,14 @@ int main(int argc, char ** argv) {
}
// tokenize new prefix and suffix
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
std::vector<llama_token> inp_pfx_cur = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx_cur = common_tokenize(ctx, params.input_suffix, false);
inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));
inp_pfx_cur.insert(inp_pfx_cur.begin(), llama_vocab_fim_pre(vocab));
inp_sfx_cur.insert(inp_sfx_cur.begin(), llama_vocab_fim_suf(vocab));
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
embd_inp = params.spm_infill ? inp_sfx_cur : inp_pfx_cur;
embd_end = params.spm_infill ? inp_pfx_cur : inp_sfx_cur;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}

View File

@ -548,11 +548,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
GGML_ASSERT(split_arg.size() <= llama_max_devices());
std::vector<float> tensor_split(llama_max_devices());
for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) {
tensor_split[i] = std::stof(split_arg[i]);
for (size_t is = 0; is < llama_max_devices(); ++is) {
if (is < split_arg.size()) {
tensor_split[is] = std::stof(split_arg[is]);
} else {
tensor_split[i] = 0.0f;
tensor_split[is] = 0.0f;
}
}
params.tensor_split.push_back(tensor_split);

View File

@ -1039,41 +1039,40 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
{ // attention
int hidden_size = 4096;
const int d_head = 128;
int n_head = hidden_size/d_head;
int hidden_size_cur = 4096;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
hidden_size_cur = 4096;
num_query = 96;
}
else if (ctx->minicpmv_version == 3) {
hidden_size = 3584;
n_head = hidden_size/d_head;
hidden_size_cur = 3584;
num_query = 64;
}
const int d_head_cur = 128;
const int n_head_cur = hidden_size_cur/d_head_cur;
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head_cur));
struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
// permute
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
Q = ggml_reshape_4d(ctx0, Q, d_head_cur, n_head_cur, num_query, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
Q = ggml_reshape_3d(ctx0, Q, d_head_cur, num_query, n_head_cur * batch_size);
K = ggml_reshape_4d(ctx0, K, d_head_cur, n_head_cur, num_positions, batch_size);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
K = ggml_reshape_3d(ctx0, K, d_head_cur, num_positions, n_head_cur * batch_size);
V = ggml_reshape_4d(ctx0, V, d_head_cur, n_head_cur, num_positions, batch_size);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
V = ggml_reshape_3d(ctx0, V, num_positions, d_head_cur, n_head_cur * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_inplace(ctx0, KQ);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
KQV = ggml_reshape_4d(ctx0, KQV, d_head_cur, num_query, n_head_cur, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
KQV = ggml_cont_3d(ctx0, KQV, hidden_size_cur, num_query, batch_size);
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
}
@ -1113,12 +1112,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
struct ggml_context * meta = NULL;
struct gguf_init_params params = {
struct gguf_init_params params_meta = {
/*.no_alloc = */ true,
/*.ctx = */ &meta,
};
struct gguf_context * ctx = gguf_init_from_file(fname, params);
struct gguf_context * ctx = gguf_init_from_file(fname, params_meta);
if (!ctx) {
throw std::runtime_error(format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
}
@ -1310,13 +1309,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// load tensors
{
std::vector<uint8_t> read_buf;
struct ggml_init_params params = {
struct ggml_init_params params_data = {
/*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
new_clip->ctx_data = ggml_init(params);
new_clip->ctx_data = ggml_init(params_data);
if (!new_clip->ctx_data) {
LOG_ERR("%s: ggml_init() failed\n", __func__);
clip_free(new_clip);

View File

@ -348,8 +348,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride;
for (int ich = 0; ich < n_chunk; ++ich) {
const int start = ich * params.ppl_stride;
const int end = start + calc_chunk;
const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
if (ich == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@ -427,9 +427,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
}
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
LOG("[%d]%.4lf,", ich + 1, std::exp(nll / count));
} else {
LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
LOG("%8d %.4lf\n", ich*params.ppl_stride, std::exp(nll / count));
}
}
LOG("\n");
@ -659,7 +659,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
for (int i = 0; i < batch.n_tokens; i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
@ -679,8 +679,8 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
}
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
for (int iv = 0; iv < n_tokens; ++iv) {
n_outputs += batch_view.logits[iv] != 0;
}
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@ -1752,14 +1752,14 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
auto kld_ptr = kld_values.data();
auto p_diff_ptr = p_diff_values.data();
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
for (int ich = 0; ich < n_chunk; ++ich) {
const int start = ich * n_ctx;
const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now();
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, ich);
return;
}
@ -1804,7 +1804,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
if (ich == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@ -1824,7 +1824,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
p_diff_ptr += n_ctx - 1 - first;
kld_ptr += n_ctx - 1 - first;
LOG("%4d", i+1);
LOG("%4d", ich + 1);
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
const double ppl_val = exp(log_ppl.first);

View File

@ -3,3 +3,4 @@ add_executable(${TARGET} run.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_compile_options(${TARGET} PRIVATE -Wno-shadow) # TMP

View File

@ -122,9 +122,9 @@ struct slot_params {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
json lora = json::array();
for (size_t i = 0; i < this->lora.size(); ++i) {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
json json_lora = json::array();
for (size_t i = 0; i < lora.size(); ++i) {
json_lora.push_back({{"id", i}, {"scale", lora[i].scale}});
}
return json {
@ -167,7 +167,7 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"lora", lora},
{"lora", json_lora},
};
}
};
@ -1641,7 +1641,7 @@ struct server_context {
llama_context_params cparams_dft;
llama_batch batch = {};
llama_batch batch_main = {};
bool clean_kv_cache = true;
bool add_bos_token = true;
@ -1676,7 +1676,7 @@ struct server_context {
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
llama_batch_free(batch_main);
}
bool load_model(const common_params & params) {
@ -1797,7 +1797,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch_main = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
@ -2655,7 +2655,7 @@ struct server_context {
}
// start populating the batch for this iteration
common_batch_clear(batch);
common_batch_clear(batch_main);
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
@ -2673,9 +2673,9 @@ struct server_context {
continue;
}
slot.i_batch = batch.n_tokens;
slot.i_batch = batch_main.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
common_batch_add(batch_main, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
@ -2692,7 +2692,7 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch_main.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
@ -2858,7 +2858,7 @@ struct server_context {
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
if (batch_main.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
}
}
@ -2878,11 +2878,11 @@ struct server_context {
slot.cache_tokens.resize(slot.n_past);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past < slot.n_prompt_tokens && batch_main.n_tokens < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
common_batch_add(batch_main, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@ -2892,13 +2892,13 @@ struct server_context {
slot.n_past++;
}
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch_main.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(batch_main.n_tokens > 0);
common_sampler_reset(slot.smpl);
@ -2908,27 +2908,27 @@ struct server_context {
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch_main.logits[batch_main.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = batch_main.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch_main.n_tokens);
}
}
if (batch.n_tokens >= n_batch) {
if (batch_main.n_tokens >= n_batch) {
break;
}
}
}
if (batch.n_tokens == 0) {
if (batch_main.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
SRV_DBG("decoding batch, n_tokens = %d\n", batch_main.n_tokens);
if (slot_batched) {
// make sure we're in the right embedding mode
@ -2938,17 +2938,17 @@ struct server_context {
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (int32_t i_batch = 0; i_batch < batch_main.n_tokens; i_batch += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch_main.n_tokens - i_batch);
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch_main.token + i_batch,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch_main.pos + i_batch,
batch_main.n_seq_id + i_batch,
batch_main.seq_id + i_batch,
batch_main.logits + i_batch,
};
const int ret = llama_decode(ctx, batch_view);
@ -2957,7 +2957,7 @@ struct server_context {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
@ -2967,15 +2967,15 @@ struct server_context {
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
i_batch -= n_batch;
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
continue; // continue loop of n_batch
}
for (auto & slot : slots) {
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
if (slot.i_batch < (int) i_batch || slot.i_batch >= (int) (i_batch + n_tokens)) {
continue; // continue loop of slots
}
@ -3001,7 +3001,7 @@ struct server_context {
continue; // continue loop of slots
}
const int tok_idx = slot.i_batch - i;
const int tok_idx = slot.i_batch - i_batch;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@ -3687,8 +3687,8 @@ int main(int argc, char ** argv) {
} else {
// multiple results (multitask)
json arr = json::array();
for (auto & res : results) {
arr.push_back(res->to_json());
for (auto & result : results) {
arr.push_back(result->to_json());
}
res_ok(res, arr);
}

View File

@ -129,15 +129,15 @@ static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_
if (p.is_string()) {
auto s = p.template get<std::string>();
llama_tokens p;
llama_tokens ids;
if (first) {
p = common_tokenize(vocab, s, add_special, parse_special);
ids = common_tokenize(vocab, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(vocab, s, false, parse_special);
ids = common_tokenize(vocab, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
prompt_tokens.insert(prompt_tokens.end(), ids.begin(), ids.end());
} else {
if (first) {
first = false;

View File

@ -544,26 +544,26 @@ int main(int argc, char ** argv) {
for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p->data[is].id;
const int s = sa[is];
const int sd = sa[is];
common_sampler_accept(drafts[s].smpl, id, true);
common_sampler_accept(drafts[sd].smpl, id, true);
drafts[s].tokens.push_back(id);
// save cur_p.data into drafts[s].dists
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
drafts[sd].tokens.push_back(id);
// save cur_p.data into drafts[sd].dists
drafts[sd].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
drafts[sd].i_batch_tgt.push_back(batch_tgt.n_tokens);
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { sd }, true);
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
drafts[sd].i_batch_dft = batch_dft.n_tokens;
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
common_batch_add(batch_dft, id, n_past_cur, { sd }, true);
if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false;
drafts[sd].drafting = false;
}
}
}