mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-13 04:00:16 +00:00)
examples : de-shadow
ggml-ci
This commit is contained in:
parent
82caffa74e
commit
9a735ae6d8
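
This commit renames local variables and parameters whose names shadow an identifier from an enclosing scope (a class member, an outer loop counter, or a file-scope variable), so that the affected sources compile cleanly with -Wshadow. Shadowing is legal C++, but it is easy to misread: an inner declaration silently captures uses that were meant for the outer name. A minimal sketch of the member/parameter pattern being removed below (illustrative only, not code from the commit):

    #include <string>
    #include <utility>

    struct arg_like {                        // hypothetical stand-in for common_arg
        std::string env;

        // before: the parameter shadows the member, forcing `this->` and
        // triggering -Wshadow
        arg_like & set_env_shadowed(std::string env) {
            this->env = std::move(env);
            return *this;
        }

        // after: renaming the parameter removes the shadowing entirely
        arg_like & set_env(std::string val) {
            env = std::move(val);
            return *this;
        }
    };
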
@@ -17,19 +17,19 @@
 
 using json = nlohmann::ordered_json;
 
-common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
-    this->examples = std::move(examples);
+common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> vals) {
+    examples = std::move(vals);
     return *this;
 }
 
-common_arg & common_arg::set_excludes(std::initializer_list<enum llama_example> excludes) {
-    this->excludes = std::move(excludes);
+common_arg & common_arg::set_excludes(std::initializer_list<enum llama_example> vals) {
+    excludes = std::move(vals);
     return *this;
 }
 
-common_arg & common_arg::set_env(const char * env) {
-    help = help + "\n(env: " + env + ")";
-    this->env = env;
+common_arg & common_arg::set_env(const char * val) {
+    help = help + "\n(env: " + val + ")";
+    env = val;
     return *this;
 }
 
@@ -46,8 +46,10 @@ bool common_arg::is_exclude(enum llama_example ex) {
     return excludes.find(ex) != excludes.end();
 }
 
-bool common_arg::get_value_from_env(std::string & output) {
-    if (env == nullptr) return false;
+bool common_arg::get_value_from_env(std::string & output) const {
+    if (env == nullptr) {
+        return false;
+    }
     char * value = std::getenv(env);
     if (value) {
         output = value;
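
Alongside the renames, the accessors get_value_from_env(), has_value_from_env() and to_string() are const-qualified here and in common/arg.h below. A small sketch of what that buys, with hypothetical names (not code from the commit): a const-qualified method stays callable through a const reference.

    #include <cstdlib>

    struct arg_like {                          // hypothetical stand-in for common_arg
        const char * env = nullptr;

        bool has_value_from_env() const {      // reads state only, so mark it const
            return env != nullptr && std::getenv(env) != nullptr;
        }
    };

    void inspect(const arg_like & a) {
        (void) a.has_value_from_env();         // OK only because the method is const
    }
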
@@ -56,7 +58,7 @@ bool common_arg::get_value_from_env(std::string & output) {
     return false;
 }
 
-bool common_arg::has_value_from_env() {
+bool common_arg::has_value_from_env() const {
     return env != nullptr && std::getenv(env);
 }
 
@@ -87,7 +89,7 @@ static std::vector<std::string> break_str_into_lines(std::string input, size_t m
     return result;
 }
 
-std::string common_arg::to_string() {
+std::string common_arg::to_string() const {
     // params for printing to console
     const static int n_leading_spaces = 40;
     const static int n_char_per_line_help = 70; // TODO: detect this based on current console
@@ -192,8 +194,6 @@ static std::string get_all_kv_cache_types() {
 //
 
 static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
-    std::string arg;
-    const std::string arg_prefix = "--";
     common_params & params = ctx_arg.params;
 
     std::unordered_map<std::string, common_arg *> arg_to_options;

common/arg.h (+12, -12)

@@ -53,15 +53,15 @@ struct common_arg {
         void (*handler)(common_params & params, const std::string &, const std::string &)
     ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
 
-    common_arg & set_examples(std::initializer_list<enum llama_example> examples);
-    common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
-    common_arg & set_env(const char * env);
+    common_arg & set_examples(std::initializer_list<enum llama_example> vals);
+    common_arg & set_excludes(std::initializer_list<enum llama_example> vals);
+    common_arg & set_env(const char * val);
     common_arg & set_sparam();
     bool in_example(enum llama_example ex);
     bool is_exclude(enum llama_example ex);
-    bool get_value_from_env(std::string & output);
-    bool has_value_from_env();
-    std::string to_string();
+    bool get_value_from_env(std::string & output) const;
+    bool has_value_from_env() const;
+    std::string to_string() const;
 };
 
 struct common_params_context {

@@ -763,10 +763,12 @@ bool fs_create_directory_with_parents(const std::string & path) {
     return true;
 #else
     // if the path already exists, check whether it's a directory
+    {
         struct stat info;
         if (stat(path.c_str(), &info) == 0) {
             return S_ISDIR(info.st_mode);
         }
+    }
 
     size_t pos_slash = 1; // skip leading slashes for directory creation
 
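
The braces added in fs_create_directory_with_parents() are the scoping variant of de-shadowing: confining `info` to its own block keeps the name from colliding with any later `stat` buffer in the same function. A reduced, self-contained sketch of the pattern (hypothetical helper name, not code from the commit):

    #include <string>
    #include <sys/stat.h>

    static bool path_is_directory(const std::string & path) {  // hypothetical
        {
            struct stat info;                       // visible only in this block
            if (stat(path.c_str(), &info) == 0) {
                return S_ISDIR(info.st_mode);
            }
        }
        // later code can declare its own `struct stat info` without shadowing
        return false;
    }
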
@@ -796,7 +798,7 @@ bool fs_create_directory_with_parents(const std::string & path) {
 }
 
 std::string fs_get_cache_directory() {
-    std::string cache_directory = "";
+    std::string cache_directory;
     auto ensure_trailing_slash = [](std::string p) {
         // Make sure to add trailing slash
         if (p.back() != DIRECTORY_SEPARATOR) {

@@ -43,7 +43,7 @@ namespace console {
     static bool simple_io = true;
     static display_t current_display = reset;
 
-    static FILE* out = stdout;
+    static FILE* fout = stdout;
 
 #if defined (_WIN32)
     static void* hConsole;
@@ -110,7 +110,7 @@ namespace console {
 
             tty = fopen("/dev/tty", "w+");
             if (tty != nullptr) {
-                out = tty;
+                fout = tty;
             }
         }
 
@@ -126,7 +126,7 @@ namespace console {
         // Restore settings on POSIX systems
         if (!simple_io) {
             if (tty != nullptr) {
-                out = stdout;
+                fout = stdout;
                 fclose(tty);
                 tty = nullptr;
             }
@@ -145,19 +145,19 @@ namespace console {
             fflush(stdout);
             switch(display) {
                 case reset:
-                    fprintf(out, ANSI_COLOR_RESET);
+                    fprintf(fout, ANSI_COLOR_RESET);
                     break;
                 case prompt:
-                    fprintf(out, ANSI_COLOR_YELLOW);
+                    fprintf(fout, ANSI_COLOR_YELLOW);
                     break;
                 case user_input:
-                    fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
+                    fprintf(fout, ANSI_BOLD ANSI_COLOR_GREEN);
                     break;
                 case error:
-                    fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
+                    fprintf(fout, ANSI_BOLD ANSI_COLOR_RED);
             }
             current_display = display;
-            fflush(out);
+            fflush(fout);
         }
     }
 
@@ -233,7 +233,7 @@ namespace console {
             return;
         }
 #endif
-        putc('\b', out);
+        putc('\b', fout);
     }
 
     static int estimateWidth(char32_t codepoint) {
@@ -274,7 +274,7 @@ namespace console {
 #else
         // We can trust expectedWidth if we've got one
        if (expectedWidth >= 0 || tty == nullptr) {
-            fwrite(utf8_codepoint, length, 1, out);
+            fwrite(utf8_codepoint, length, 1, fout);
            return expectedWidth;
        }
 
@@ -311,7 +311,7 @@ namespace console {
         pop_cursor();
         put_codepoint(&ch, 1, 1);
 #else
-        fprintf(out, "\b%c", ch);
+        fprintf(fout, "\b%c", ch);
 #endif
     }
 
@@ -353,7 +353,7 @@ namespace console {
     }
 
     static bool readline_advanced(std::string & line, bool multiline_input) {
-        if (out != stdout) {
+        if (fout != stdout) {
             fflush(stdout);
         }
 
@@ -364,7 +364,7 @@ namespace console {
 
         char32_t input_char;
         while (true) {
-            fflush(out); // Ensure all output is displayed before waiting for input
+            fflush(fout); // Ensure all output is displayed before waiting for input
             input_char = getchar32();
 
             if (input_char == '\r' || input_char == '\n') {
@@ -432,7 +432,7 @@ namespace console {
                 line.pop_back();
                 if (last == '\\') {
                     line += '\n';
-                    fputc('\n', out);
+                    fputc('\n', fout);
                     has_more = !has_more;
                 } else {
                     // llama will just eat the single space, it won't act as a space
@@ -447,11 +447,11 @@ namespace console {
                 has_more = false;
             } else {
                 line += '\n';
-                fputc('\n', out);
+                fputc('\n', fout);
             }
         }
 
-        fflush(out);
+        fflush(fout);
         return has_more;
     }
 

@@ -338,16 +338,16 @@ public:
         resume();
     }
 
-    void set_prefix(bool prefix) {
+    void set_prefix(bool val) {
         std::lock_guard<std::mutex> lock(mtx);
 
-        this->prefix = prefix;
+        prefix = val;
     }
 
-    void set_timestamps(bool timestamps) {
+    void set_timestamps(bool val) {
         std::lock_guard<std::mutex> lock(mtx);
 
-        this->timestamps = timestamps;
+        timestamps = val;
     }
 };
 

@@ -471,12 +471,12 @@ struct my_llama_file {
         GGML_ASSERT(ret == 0); // same
     }
 
-    void read_raw(void * ptr, size_t size) {
-        if (size == 0) {
+    void read_raw(void * ptr, size_t size_cur) {
+        if (size_cur == 0) {
             return;
         }
         errno = 0;
-        std::size_t ret = std::fread(ptr, size, 1, fp);
+        std::size_t ret = std::fread(ptr, size_cur, 1, fp);
         if (ferror(fp)) {
             die_fmt("fread failed: %s", strerror(errno));
         }

@@ -60,13 +60,6 @@ int main(int argc, char** argv) {
     const std::string grammar_filename = argv[1];
     const std::string input_filename = argv[2];
 
-    // Read the GBNF grammar file
-    FILE* grammar_file = fopen(grammar_filename.c_str(), "r");
-    if (!grammar_file) {
-        fprintf(stdout, "Failed to open grammar file: %s\n", grammar_filename.c_str());
-        return 1;
-    }
-
     std::string grammar_str;
     {
         std::ifstream grammar_file(grammar_filename);

@@ -294,7 +294,7 @@ void IMatrixCollector::save_imatrix(int ncall) const {
 bool IMatrixCollector::load_imatrix(const char * fname) {
     std::ifstream in(fname, std::ios::binary);
     if (!in) {
-        LOG_ERR("%s: failed to open %s\n",__func__, fname);
+        LOG_ERR("%s: failed to open %s\n", __func__, fname);
         return false;
     }
     int n_entries;
@@ -308,7 +308,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
         std::vector<char> name_as_vec(len+1);
         in.read((char *)name_as_vec.data(), len);
         if (in.fail()) {
-            LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
+            LOG_ERR("%s: failed reading name for entry %d from %s\n", __func__, i + 1, fname);
             return false;
         }
         name_as_vec[len] = 0;
@@ -319,7 +319,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
         int nval;
         in.read((char *)&nval, sizeof(nval));
         if (in.fail() || nval < 1) {
-            LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i);
+            LOG_ERR("%s: failed reading number of values for entry %d\n", __func__, i);
             m_stats = {};
             return false;
         }
@@ -332,15 +332,15 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
         std::vector<float> tmp(nval);
         in.read((char*)tmp.data(), nval*sizeof(float));
         if (in.fail()) {
-            LOG_ERR("%s: failed reading data for entry %d\n",__func__,i);
+            LOG_ERR("%s: failed reading data for entry %d\n", __func__, i);
             m_stats = {};
             return false;
         }
 
         // Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
-        for (int i = 0; i < nval; i++) {
-            e.values[i] += tmp[i];
-            e.counts[i] += ncall;
+        for (int j = 0; j < nval; j++) {
+            e.values[j] += tmp[j];
+            e.counts[j] += ncall;
         }
         e.ncall += ncall;
 
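
The rename in load_imatrix() above is the classic nested-loop case: the inner counter reused the name `i` from the enclosing entry loop, so any use of the outer index inside the body would silently bind to the inner one. A reduced, self-contained sketch (hypothetical names, not the original code):

    #include <cstdio>
    #include <vector>

    static void accumulate(std::vector<float> & values,
                           const std::vector<float> & tmp, int n_entries) {
        for (int i = 0; i < n_entries; ++i) {
            // inner counter renamed from `i` to `j` so it no longer shadows
            // the entry index
            for (int j = 0; j < (int) tmp.size(); ++j) {
                values[j] += tmp[j];
            }
            std::printf("entry %d done\n", i);  // unambiguously the entry index
        }
    }
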
@@ -488,12 +488,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * n_ctx;
+    for (int ich = 0; ich < n_chunk; ++ich) {
+        const int start = ich * n_ctx;
         const int end = start + n_ctx;
 
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -537,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
-        if (i == 0) {
+        if (ich == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
@@ -555,7 +553,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;
 
-        LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        LOG("[%d]%.4lf,", ich + 1, std::exp(nll / count));
         fflush(stdout);
 
         logits.clear();

@@ -462,14 +462,14 @@ int main(int argc, char ** argv) {
             }
 
             // tokenize new prefix and suffix
-            std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
-            std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
+            std::vector<llama_token> inp_pfx_cur = common_tokenize(ctx, params.input_prefix, false);
+            std::vector<llama_token> inp_sfx_cur = common_tokenize(ctx, params.input_suffix, false);
 
-            inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
-            inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));
+            inp_pfx_cur.insert(inp_pfx_cur.begin(), llama_vocab_fim_pre(vocab));
+            inp_sfx_cur.insert(inp_sfx_cur.begin(), llama_vocab_fim_suf(vocab));
 
-            embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
-            embd_end = params.spm_infill ? inp_pfx : inp_sfx;
+            embd_inp = params.spm_infill ? inp_sfx_cur : inp_pfx_cur;
+            embd_end = params.spm_infill ? inp_pfx_cur : inp_sfx_cur;
             if (add_bos) {
                 embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
             }

@@ -548,11 +548,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 GGML_ASSERT(split_arg.size() <= llama_max_devices());
 
                 std::vector<float> tensor_split(llama_max_devices());
-                for (size_t i = 0; i < llama_max_devices(); ++i) {
-                    if (i < split_arg.size()) {
-                        tensor_split[i] = std::stof(split_arg[i]);
+                for (size_t is = 0; is < llama_max_devices(); ++is) {
+                    if (is < split_arg.size()) {
+                        tensor_split[is] = std::stof(split_arg[is]);
                     } else {
-                        tensor_split[i] = 0.0f;
+                        tensor_split[is] = 0.0f;
                     }
                 }
                 params.tensor_split.push_back(tensor_split);

@@ -1039,41 +1039,40 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         }
 
         { // attention
-            int hidden_size = 4096;
-            const int d_head = 128;
-            int n_head = hidden_size/d_head;
+            int hidden_size_cur = 4096;
             int num_query = 96;
             if (ctx->minicpmv_version == 2) {
-                hidden_size = 4096;
-                n_head = hidden_size/d_head;
+                hidden_size_cur = 4096;
                 num_query = 96;
             }
             else if (ctx->minicpmv_version == 3) {
-                hidden_size = 3584;
-                n_head = hidden_size/d_head;
+                hidden_size_cur = 3584;
                 num_query = 64;
             }
 
+            const int d_head_cur = 128;
+            const int n_head_cur = hidden_size_cur/d_head_cur;
+
             struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
-            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head_cur));
             struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
             struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
             // permute
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+            Q = ggml_reshape_4d(ctx0, Q, d_head_cur, n_head_cur, num_query, batch_size);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            Q = ggml_reshape_3d(ctx0, Q, d_head_cur, num_query, n_head_cur * batch_size);
+            K = ggml_reshape_4d(ctx0, K, d_head_cur, n_head_cur, num_positions, batch_size);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            K = ggml_reshape_3d(ctx0, K, d_head_cur, num_positions, n_head_cur * batch_size);
+            V = ggml_reshape_4d(ctx0, V, d_head_cur, n_head_cur, num_positions, batch_size);
             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head_cur, n_head_cur * batch_size);
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             KQ = ggml_soft_max_inplace(ctx0, KQ);
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head_cur, num_query, n_head_cur, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+            KQV = ggml_cont_3d(ctx0, KQV, hidden_size_cur, num_query, batch_size);
 
             embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
         }
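
In the clip.cpp attention block above, the head-count bookkeeping is also simplified: rather than recomputing n_head inside every minicpmv_version branch, the new code settles hidden_size_cur in the branches and derives the head count once afterwards. The two forms are equivalent because d_head is a constant; a compact sketch under that assumption (hypothetical helper, not code from the commit):

    static int compute_n_head(int minicpmv_version) {  // hypothetical helper
        int hidden_size_cur = 4096;                    // version 2 value
        if (minicpmv_version == 3) {
            hidden_size_cur = 3584;
        }
        const int d_head_cur = 128;
        return hidden_size_cur / d_head_cur;           // 32 for v2, 28 for v3
    }
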
@@ -1113,12 +1112,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     struct ggml_context * meta = NULL;
 
-    struct gguf_init_params params = {
+    struct gguf_init_params params_meta = {
         /*.no_alloc = */ true,
         /*.ctx = */ &meta,
     };
 
-    struct gguf_context * ctx = gguf_init_from_file(fname, params);
+    struct gguf_context * ctx = gguf_init_from_file(fname, params_meta);
     if (!ctx) {
         throw std::runtime_error(format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
     }
@@ -1310,13 +1309,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     // load tensors
     {
         std::vector<uint8_t> read_buf;
-        struct ggml_init_params params = {
+        struct ggml_init_params params_data = {
             /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true,
         };
 
-        new_clip->ctx_data = ggml_init(params);
+        new_clip->ctx_data = ggml_init(params_data);
         if (!new_clip->ctx_data) {
             LOG_ERR("%s: ggml_init() failed\n", __func__);
             clip_free(new_clip);

@@ -348,8 +348,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
 
     LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
-    for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * params.ppl_stride;
+    for (int ich = 0; ich < n_chunk; ++ich) {
+        const int start = ich * params.ppl_stride;
         const int end = start + calc_chunk;
 
         const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
@@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
-        if (i == 0) {
+        if (ich == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
@@ -427,9 +427,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         }
         // perplexity is e^(average negative log-likelihood)
         if (params.ppl_output_type == 0) {
-            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
+            LOG("[%d]%.4lf,", ich + 1, std::exp(nll / count));
         } else {
-            LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+            LOG("%8d %.4lf\n", ich*params.ppl_stride, std::exp(nll / count));
         }
     }
     LOG("\n");
@@ -659,7 +659,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
 
 static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
     int prev_outputs = 0;
-    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
+    for (int i = 0; i < batch.n_tokens; i += n_batch) {
         const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
 
         llama_batch batch_view = {
@@ -679,8 +679,8 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
         }
 
         int n_outputs = 0;
-        for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+        for (int iv = 0; iv < n_tokens; ++iv) {
+            n_outputs += batch_view.logits[iv] != 0;
         }
 
         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -1752,14 +1752,14 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
     auto kld_ptr = kld_values.data();
     auto p_diff_ptr = p_diff_values.data();
 
-    for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * n_ctx;
+    for (int ich = 0; ich < n_chunk; ++ich) {
+        const int start = ich * n_ctx;
         const int end = start + n_ctx;
 
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
-            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
+            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, ich);
             return;
         }
 
@@ -1804,7 +1804,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
-        if (i == 0) {
+        if (ich == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
@@ -1824,7 +1824,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         p_diff_ptr += n_ctx - 1 - first;
         kld_ptr += n_ctx - 1 - first;
 
-        LOG("%4d", i+1);
+        LOG("%4d", ich + 1);
 
         auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
         const double ppl_val = exp(log_ppl.first);

@@ -3,3 +3,4 @@ add_executable(${TARGET} run.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+target_compile_options(${TARGET} PRIVATE -Wno-shadow) # TMP
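
The run example opts out with -Wno-shadow (marked "# TMP"), presumably until it is de-shadowed as well; everywhere else the point of the commit is that the code now builds with shadowing warnings enabled. A self-contained way to reproduce the diagnostic being eliminated (hypothetical file, compile with `g++ -Wshadow -c shadow_demo.cpp`):

    // shadow_demo.cpp (hypothetical)
    int shadowed(int x) {
        int sum = 0;
        for (int i = 0; i < x; ++i) {
            int x = i * 2;  // warning: declaration of 'x' shadows a parameter
            sum += x;
        }
        return sum;
    }
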
@@ -122,9 +122,9 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
-        json lora = json::array();
-        for (size_t i = 0; i < this->lora.size(); ++i) {
-            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        json json_lora = json::array();
+        for (size_t i = 0; i < lora.size(); ++i) {
+            json_lora.push_back({{"id", i}, {"scale", lora[i].scale}});
         }
 
         return json {
@@ -167,7 +167,7 @@ struct slot_params {
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
-            {"lora", lora},
+            {"lora", json_lora},
         };
     }
 };
@@ -1641,7 +1641,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    llama_batch batch_main = {};
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1676,7 +1676,7 @@ struct server_context {
             llama_batch_free(slot.batch_spec);
         }
 
-        llama_batch_free(batch);
+        llama_batch_free(batch_main);
     }
 
     bool load_model(const common_params & params) {
@@ -1797,7 +1797,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch_main = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -2655,7 +2655,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        common_batch_clear(batch_main);
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2673,9 +2673,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = batch_main.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch_main, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2692,7 +2692,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch_main.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -2858,7 +2858,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (batch_main.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -2878,11 +2878,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch_main.n_tokens < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch_main, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -2892,13 +2892,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch_main.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(batch_main.n_tokens > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -2908,27 +2908,27 @@ struct server_context {
                     }
 
                     // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                    batch_main.logits[batch_main.n_tokens - 1] = true;
 
                     slot.n_decoded = 0;
-                    slot.i_batch = batch.n_tokens - 1;
+                    slot.i_batch = batch_main.n_tokens - 1;
 
-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch_main.n_tokens);
                 }
             }
 
-            if (batch.n_tokens >= n_batch) {
+            if (batch_main.n_tokens >= n_batch) {
                 break;
            }
        }
    }
 
-    if (batch.n_tokens == 0) {
+    if (batch_main.n_tokens == 0) {
        SRV_WRN("%s", "no tokens to decode\n");
        return;
    }
 
-    SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+    SRV_DBG("decoding batch, n_tokens = %d\n", batch_main.n_tokens);
 
    if (slot_batched) {
        // make sure we're in the right embedding mode
@@ -2938,17 +2938,17 @@ struct server_context {
        }
 
        // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+        for (int32_t i_batch = 0; i_batch < batch_main.n_tokens; i_batch += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch_main.n_tokens - i_batch);
 
            llama_batch batch_view = {
                n_tokens,
-                batch.token + i,
+                batch_main.token + i_batch,
                nullptr,
-                batch.pos + i,
-                batch.n_seq_id + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch_main.pos + i_batch,
+                batch_main.n_seq_id + i_batch,
+                batch_main.seq_id + i_batch,
+                batch_main.logits + i_batch,
            };
 
            const int ret = llama_decode(ctx, batch_view);
@@ -2957,7 +2957,7 @@ struct server_context {
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
-                    SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                    SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
                    for (auto & slot : slots) {
                        slot.release();
                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
@@ -2967,15 +2967,15 @@ struct server_context {
 
                    // retry with half the batch size to try to find a free slot in the KV cache
                    n_batch /= 2;
-                    i -= n_batch;
+                    i_batch -= n_batch;
 
-                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
 
                    continue; // continue loop of n_batch
                }
 
                for (auto & slot : slots) {
-                    if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
+                    if (slot.i_batch < (int) i_batch || slot.i_batch >= (int) (i_batch + n_tokens)) {
                        continue; // continue loop of slots
                    }
 
@@ -3001,7 +3001,7 @@ struct server_context {
                    continue; // continue loop of slots
                }
 
-                const int tok_idx = slot.i_batch - i;
+                const int tok_idx = slot.i_batch - i_batch;
 
                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
@@ -3687,8 +3687,8 @@ int main(int argc, char ** argv) {
        } else {
            // multiple results (multitask)
            json arr = json::array();
-            for (auto & res : results) {
-                arr.push_back(res->to_json());
+            for (auto & result : results) {
+                arr.push_back(result->to_json());
            }
            res_ok(res, arr);
        }

@@ -129,15 +129,15 @@ static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_
        if (p.is_string()) {
            auto s = p.template get<std::string>();
 
-            llama_tokens p;
+            llama_tokens ids;
            if (first) {
-                p = common_tokenize(vocab, s, add_special, parse_special);
+                ids = common_tokenize(vocab, s, add_special, parse_special);
                first = false;
            } else {
-                p = common_tokenize(vocab, s, false, parse_special);
+                ids = common_tokenize(vocab, s, false, parse_special);
            }
 
-            prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+            prompt_tokens.insert(prompt_tokens.end(), ids.begin(), ids.end());
        } else {
            if (first) {
                first = false;

@@ -544,26 +544,26 @@ int main(int argc, char ** argv) {
            for (int is = 0; is < (int) sa.size(); ++is) {
                const llama_token id = cur_p->data[is].id;
 
-                const int s = sa[is];
+                const int sd = sa[is];
 
-                common_sampler_accept(drafts[s].smpl, id, true);
+                common_sampler_accept(drafts[sd].smpl, id, true);
 
-                drafts[s].tokens.push_back(id);
-                // save cur_p.data into drafts[s].dists
-                drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
+                drafts[sd].tokens.push_back(id);
+                // save cur_p.data into drafts[sd].dists
+                drafts[sd].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
 
                // add unique drafted tokens to the target batch
-                drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
+                drafts[sd].i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-                common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+                common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { sd }, true);
 
                // add the token to the batch for batched decoding with the draft model
-                drafts[s].i_batch_dft = batch_dft.n_tokens;
+                drafts[sd].i_batch_dft = batch_dft.n_tokens;
 
-                common_batch_add(batch_dft, id, n_past_cur, { s }, true);
+                common_batch_add(batch_dft, id, n_past_cur, { sd }, true);
 
                if (batch_tgt.n_tokens > n_draft) {
-                    drafts[s].drafting = false;
+                    drafts[sd].drafting = false;
                }
            }
        }