mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
Add support for file load progress reporting callbacks (#434)
* File load progress reporting * Move llama_progress_handler into llama_context_params * Renames * Use seekg to find file size instead * More correct load progress * Call progress callback more frequently * Fix typo
This commit is contained in:
parent
36d07532ef
commit
58e6c9f36f
26
llama.cpp
26
llama.cpp
@ -275,6 +275,8 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
/*.use_mlock =*/ false,
|
/*.use_mlock =*/ false,
|
||||||
/*.embedding =*/ false,
|
/*.embedding =*/ false,
|
||||||
|
/*.progress_callback =*/ nullptr,
|
||||||
|
/*.progress_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
@ -290,7 +292,9 @@ static bool llama_model_load(
|
|||||||
int n_ctx,
|
int n_ctx,
|
||||||
int n_parts,
|
int n_parts,
|
||||||
ggml_type memory_type,
|
ggml_type memory_type,
|
||||||
bool vocab_only) {
|
bool vocab_only,
|
||||||
|
llama_progress_callback progress_callback,
|
||||||
|
void *progress_callback_user_data) {
|
||||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||||
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
const int64_t t_start_us = ggml_time_us();
|
||||||
@ -576,6 +580,10 @@ static bool llama_model_load(
|
|||||||
|
|
||||||
std::vector<uint8_t> tmp;
|
std::vector<uint8_t> tmp;
|
||||||
|
|
||||||
|
if (progress_callback) {
|
||||||
|
progress_callback(0.0, progress_callback_user_data);
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_parts; ++i) {
|
for (int i = 0; i < n_parts; ++i) {
|
||||||
const int part_id = i;
|
const int part_id = i;
|
||||||
//const int part_id = n_parts - i - 1;
|
//const int part_id = n_parts - i - 1;
|
||||||
@ -589,6 +597,10 @@ static bool llama_model_load(
|
|||||||
|
|
||||||
fin = std::ifstream(fname_part, std::ios::binary);
|
fin = std::ifstream(fname_part, std::ios::binary);
|
||||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||||
|
|
||||||
|
fin.seekg(0, fin.end);
|
||||||
|
const size_t file_size = fin.tellg();
|
||||||
|
|
||||||
fin.seekg(file_offset);
|
fin.seekg(file_offset);
|
||||||
|
|
||||||
// load weights
|
// load weights
|
||||||
@ -764,6 +776,11 @@ static bool llama_model_load(
|
|||||||
model.n_loaded++;
|
model.n_loaded++;
|
||||||
|
|
||||||
// progress
|
// progress
|
||||||
|
if (progress_callback) {
|
||||||
|
double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
|
||||||
|
double current_progress = (double(i) + current_file_progress) / double(n_parts);
|
||||||
|
progress_callback(current_progress, progress_callback_user_data);
|
||||||
|
}
|
||||||
if (model.n_loaded % 8 == 0) {
|
if (model.n_loaded % 8 == 0) {
|
||||||
fprintf(stderr, ".");
|
fprintf(stderr, ".");
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
@ -786,6 +803,10 @@ static bool llama_model_load(
|
|||||||
|
|
||||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||||
|
|
||||||
|
if (progress_callback) {
|
||||||
|
progress_callback(1.0, progress_callback_user_data);
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
|
|||||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
|
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
|
||||||
params.vocab_only)) {
|
params.vocab_only, params.progress_callback,
|
||||||
|
params.progress_callback_user_data)) {
|
||||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
7
llama.h
7
llama.h
@ -45,6 +45,8 @@ extern "C" {
|
|||||||
|
|
||||||
} llama_token_data;
|
} llama_token_data;
|
||||||
|
|
||||||
|
typedef void (*llama_progress_callback)(double progress, void *ctx);
|
||||||
|
|
||||||
struct llama_context_params {
|
struct llama_context_params {
|
||||||
int n_ctx; // text context
|
int n_ctx; // text context
|
||||||
int n_parts; // -1 for default
|
int n_parts; // -1 for default
|
||||||
@ -55,6 +57,11 @@ extern "C" {
|
|||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
bool use_mlock; // force system to keep model in RAM
|
bool use_mlock; // force system to keep model in RAM
|
||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
|
|
||||||
|
// called with a progress value between 0 and 1, pass NULL to disable
|
||||||
|
llama_progress_callback progress_callback;
|
||||||
|
// context pointer passed to the progress callback
|
||||||
|
void * progress_callback_user_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||||
|
Loading…
Reference in New Issue
Block a user