Add support for file load progress reporting callbacks (#434)

* File load progress reporting

* Move llama_progress_handler into llama_context_params

* Renames

* Use seekg to find file size instead

* More correct load progress

* Call progress callback more frequently

* Fix typo
This commit is contained in:
Jed Fox 2023-03-25 01:26:28 -04:00 committed by GitHub
parent 36d07532ef
commit 58e6c9f36f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 10 deletions

View File

@ -275,6 +275,8 @@ struct llama_context_params llama_context_default_params() {
/*.vocab_only =*/ false, /*.vocab_only =*/ false,
/*.use_mlock =*/ false, /*.use_mlock =*/ false,
/*.embedding =*/ false, /*.embedding =*/ false,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
}; };
return result; return result;
@ -290,7 +292,9 @@ static bool llama_model_load(
int n_ctx, int n_ctx,
int n_parts, int n_parts,
ggml_type memory_type, ggml_type memory_type,
bool vocab_only) { bool vocab_only,
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
@ -576,6 +580,10 @@ static bool llama_model_load(
std::vector<uint8_t> tmp; std::vector<uint8_t> tmp;
if (progress_callback) {
progress_callback(0.0, progress_callback_user_data);
}
for (int i = 0; i < n_parts; ++i) { for (int i = 0; i < n_parts; ++i) {
const int part_id = i; const int part_id = i;
//const int part_id = n_parts - i - 1; //const int part_id = n_parts - i - 1;
@ -589,6 +597,10 @@ static bool llama_model_load(
fin = std::ifstream(fname_part, std::ios::binary); fin = std::ifstream(fname_part, std::ios::binary);
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
fin.seekg(0, fin.end);
const size_t file_size = fin.tellg();
fin.seekg(file_offset); fin.seekg(file_offset);
// load weights // load weights
@ -764,6 +776,11 @@ static bool llama_model_load(
model.n_loaded++; model.n_loaded++;
// progress // progress
if (progress_callback) {
double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
double current_progress = (double(i) + current_file_progress) / double(n_parts);
progress_callback(current_progress, progress_callback_user_data);
}
if (model.n_loaded % 8 == 0) { if (model.n_loaded % 8 == 0) {
fprintf(stderr, "."); fprintf(stderr, ".");
fflush(stderr); fflush(stderr);
@ -786,6 +803,10 @@ static bool llama_model_load(
lctx.t_load_us = ggml_time_us() - t_start_us; lctx.t_load_us = ggml_time_us() - t_start_us;
if (progress_callback) {
progress_callback(1.0, progress_callback_user_data);
}
return true; return true;
} }
@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type, if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
params.vocab_only)) { params.vocab_only, params.progress_callback,
params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__); fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx); llama_free(ctx);
return nullptr; return nullptr;

View File

@ -45,6 +45,8 @@ extern "C" {
} llama_token_data; } llama_token_data;
// Progress-reporting callback type: `progress` is the fraction of the load completed
// (0.0 at the start, 1.0 when finished); `ctx` is the caller's user-data pointer,
// passed back unchanged.
typedef void (*llama_progress_callback)(double progress, void *ctx);
struct llama_context_params { struct llama_context_params {
int n_ctx; // text context int n_ctx; // text context
int n_parts; // -1 for default int n_parts; // -1 for default
@ -55,6 +57,11 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights bool vocab_only; // only load the vocabulary, no weights
bool use_mlock; // force system to keep model in RAM bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
}; };
LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_context_params llama_context_default_params();