mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
lora: load to devide buft
This commit is contained in:
parent
213701b51a
commit
67c5e14d06
@ -2063,14 +2063,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
||||
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
||||
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
||||
int err = llama_model_apply_lora_from_file(model,
|
||||
lora_adapter.c_str(),
|
||||
lora_scale,
|
||||
((i > 0) || params.lora_base.empty())
|
||||
? NULL
|
||||
: params.lora_base.c_str(),
|
||||
params.n_threads);
|
||||
if (err != 0) {
|
||||
auto adapter = llama_lora_adapter_init(lctx, lora_adapter.c_str());
|
||||
if (adapter == nullptr) {
|
||||
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||
llama_free(lctx);
|
||||
llama_free_model(model);
|
||||
|
@ -406,6 +406,9 @@ extern "C" {
|
||||
const char * content;
|
||||
} llama_chat_message;
|
||||
|
||||
// lora adapter
|
||||
struct llama_lora_adapter;
|
||||
|
||||
// Helpers for getting default parameters
|
||||
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||
@ -510,13 +513,9 @@ extern "C" {
|
||||
// the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API int32_t llama_model_apply_lora_from_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_lora,
|
||||
float scale,
|
||||
const char * path_base_model,
|
||||
int32_t n_threads);
|
||||
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora);
|
||||
|
||||
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
||||
// the currently loaded vector.
|
||||
|
423
src/llama.cpp
423
src/llama.cpp
@ -2547,6 +2547,29 @@ struct llama_control_vector {
|
||||
}
|
||||
};
|
||||
|
||||
struct lora_weight {
|
||||
struct ggml_tensor * a = nullptr;
|
||||
struct ggml_tensor * b = nullptr;
|
||||
lora_weight() {}
|
||||
lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
|
||||
};
|
||||
|
||||
struct llama_lora_adapter {
|
||||
// map tensor name to lora_a_b
|
||||
std::map<std::string, lora_weight> ab_map;
|
||||
std::vector<struct ggml_context *> ctxs;
|
||||
std::vector<ggml_backend_buffer_t> bufs;
|
||||
|
||||
~llama_lora_adapter() {
|
||||
for (struct ggml_context * ctx : ctxs) {
|
||||
ggml_free(ctx);
|
||||
}
|
||||
for (ggml_backend_buffer_t buf : bufs) {
|
||||
ggml_backend_buffer_free(buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
@ -2704,6 +2727,10 @@ struct llama_context {
|
||||
}
|
||||
|
||||
ggml_backend_buffer_free(buf_output);
|
||||
|
||||
for (auto adapter : lora_adapters) {
|
||||
delete adapter;
|
||||
}
|
||||
}
|
||||
|
||||
llama_cparams cparams;
|
||||
@ -2795,6 +2822,9 @@ struct llama_context {
|
||||
|
||||
// control vectors
|
||||
struct llama_control_vector cvec;
|
||||
|
||||
// lora adapters
|
||||
std::vector<struct llama_lora_adapter *> lora_adapters;
|
||||
};
|
||||
|
||||
static size_t llama_get_device_count(const llama_model & model) {
|
||||
@ -18243,281 +18273,149 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
}
|
||||
}
|
||||
|
||||
static int llama_apply_lora_from_file_internal(
|
||||
const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
|
||||
) {
|
||||
static int llama_lora_adapter_init_internal(const struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) {
|
||||
static const int n_inp_tensors = 5; // see llama_model
|
||||
static const int n_out_tensors = 5; // see llama_model
|
||||
LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
|
||||
|
||||
const int64_t t_start_lora_us = ggml_time_us();
|
||||
// TODO: check lora base model arch
|
||||
|
||||
llama_file fin(path_lora, "rb");
|
||||
|
||||
// verify magic and version
|
||||
{
|
||||
uint32_t magic = fin.read_u32();
|
||||
if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||
LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
uint32_t format_version = fin.read_u32();
|
||||
if (format_version != 1) {
|
||||
LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t lora_r = fin.read_u32();
|
||||
int32_t lora_alpha = fin.read_u32();
|
||||
float scaling = scale * (float)lora_alpha / (float)lora_r;
|
||||
|
||||
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||
|
||||
// load base model
|
||||
std::unique_ptr<llama_model_loader> ml;
|
||||
if (path_base_model) {
|
||||
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
|
||||
ml->init_mappings(/*prefetch*/ false); // no prefetching
|
||||
}
|
||||
|
||||
struct tensor_meta {
|
||||
std::string name;
|
||||
ggml_type type;
|
||||
int32_t ne[2];
|
||||
size_t offset;
|
||||
ggml_context * ctx = nullptr;
|
||||
struct gguf_init_params meta_gguf_params = {
|
||||
/* .no_alloc = */ false,
|
||||
/* .ctx = */ &ctx,
|
||||
};
|
||||
std::map<std::string, tensor_meta> tensor_meta_map;
|
||||
|
||||
// load all tensor meta
|
||||
while (true) {
|
||||
if (fin.tell() == fin.size) {
|
||||
// eof
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t n_dims;
|
||||
int32_t name_len;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read_raw(&n_dims, sizeof(n_dims));
|
||||
fin.read_raw(&name_len, sizeof(name_len));
|
||||
fin.read_raw(&ftype, sizeof(ftype));
|
||||
|
||||
if (n_dims != 1 && n_dims != 2) {
|
||||
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read_raw(&ne[i], sizeof(ne[i]));
|
||||
}
|
||||
|
||||
std::string name;
|
||||
{
|
||||
GGML_ASSERT(name_len < GGML_MAX_NAME);
|
||||
char buf[GGML_MAX_NAME];
|
||||
fin.read_raw(buf, name_len);
|
||||
name = std::string(buf, name_len);
|
||||
}
|
||||
|
||||
// check for lora suffix
|
||||
std::string lora_suffix;
|
||||
if (name.length() > 6) {
|
||||
lora_suffix = name.substr(name.length() - 6);
|
||||
}
|
||||
if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
|
||||
LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// tensor type
|
||||
ggml_type wtype;
|
||||
switch (ftype) {
|
||||
case 0: wtype = GGML_TYPE_F32; break;
|
||||
case 1: wtype = GGML_TYPE_F16; break;
|
||||
default:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
|
||||
__func__, ftype);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// data offset
|
||||
size_t offset = fin.tell();
|
||||
offset = (offset + 31) & -32;
|
||||
|
||||
// skip tensor data
|
||||
fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
|
||||
|
||||
tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
|
||||
struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to load lora adapter file from %s\n", __func__, path_lora);
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool warned = false;
|
||||
int n_tensors = 0;
|
||||
|
||||
// apply
|
||||
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
|
||||
if (backend_cpu == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
|
||||
|
||||
std::vector<no_init<uint8_t>> read_buf;
|
||||
for (const auto & it : model.tensors_by_name) {
|
||||
const std::string & base_name = it.first;
|
||||
ggml_tensor * model_t = it.second;
|
||||
|
||||
if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
|
||||
tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
|
||||
continue;
|
||||
// calculate n_tensors_per_layer
|
||||
int n_tensors_per_layer = 0;
|
||||
{
|
||||
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
|
||||
for (int i = 0; i < n_tensors; i++) {
|
||||
int il = -1;
|
||||
sscanf(gguf_get_tensor_name(ctx_gguf, i), "blk.%d.", &il);
|
||||
if (il == 0) n_tensors_per_layer++;
|
||||
}
|
||||
}
|
||||
printf("n_tensors_per_layer %d\n", n_tensors_per_layer);
|
||||
|
||||
tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
|
||||
tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
|
||||
// count layer buffer types
|
||||
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
|
||||
for (int64_t i = 0; i < model.hparams.n_layer; i++) {
|
||||
buft_layer_count[model.buft_layer[i].buft]++;
|
||||
}
|
||||
|
||||
ggml_init_params lora_init_params = {
|
||||
/* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
|
||||
/* .mem_buffer */ nullptr,
|
||||
/* .no_alloc */ true,
|
||||
// allocate contexts
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
{
|
||||
auto new_ggml_ctx = [](size_t n_tensors) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
return ggml_init(params);
|
||||
};
|
||||
ggml_context * lora_ctx = ggml_init(lora_init_params);
|
||||
if (lora_ctx == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
|
||||
ggml_backend_free(backend_cpu);
|
||||
return 1;
|
||||
for (auto & it : buft_layer_count) {
|
||||
int n_layers = it.second;
|
||||
printf("buf %p layers %d\n", it.first, it.second);
|
||||
ctx_map[it.first] = new_ggml_ctx(2*n_layers*n_tensors_per_layer);
|
||||
}
|
||||
//ctx_map[model.buft_input.buft] = new_ggml_ctx(2*n_inp_tensors);
|
||||
//ctx_map[model.buft_output.buft] = new_ggml_ctx(2*n_out_tensors);
|
||||
}
|
||||
|
||||
// create tensors
|
||||
ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
|
||||
ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
|
||||
ggml_set_name(loraA, metaA.name.c_str());
|
||||
ggml_set_name(loraB, metaB.name.c_str());
|
||||
|
||||
ggml_tensor * base_t;
|
||||
if (ml) {
|
||||
if (!ml->get_tensor_meta(base_name.c_str())) {
|
||||
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||
return 1;
|
||||
// bundle lora_a and lora_b into pairs
|
||||
std::map<std::string, lora_weight> ab_map;
|
||||
auto str_endswith = [](const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
};
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
||||
std::string name(cur->name);
|
||||
if (str_endswith(name, ".lora_a")) {
|
||||
replace_all(name, ".lora_a", "");
|
||||
if (ab_map.find(name) == ab_map.end()) {
|
||||
ab_map[name] = lora_weight(cur, nullptr);
|
||||
} else {
|
||||
ab_map[name].a = cur;
|
||||
}
|
||||
base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
|
||||
} else if (str_endswith(name, ".lora_b")) {
|
||||
replace_all(name, ".lora_b", "");
|
||||
if (ab_map.find(name) == ab_map.end()) {
|
||||
ab_map[name] = lora_weight(nullptr, cur);
|
||||
} else {
|
||||
ab_map[name].b = cur;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add tensors
|
||||
for (auto & it : ab_map) {
|
||||
std::string name = it.first;
|
||||
lora_weight & w = it.second;
|
||||
GGML_ASSERT(w.a != nullptr);
|
||||
GGML_ASSERT(w.b != nullptr);
|
||||
int il = -1;
|
||||
sscanf(name.c_str(), "blk.%d.", &il);
|
||||
if (il >= 0) {
|
||||
printf("%s %p %p\n", name.c_str(), w.a, w.b);
|
||||
struct ggml_context * dev_ctx = ctx_map.at(model.buft_layer[il].buft);
|
||||
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
||||
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
||||
ggml_set_name(tensor_a, w.a->name);
|
||||
ggml_set_name(tensor_b, w.b->name);
|
||||
adapter.ab_map[name] = lora_weight(tensor_a, tensor_b);
|
||||
} else {
|
||||
base_t = ggml_dup_tensor(lora_ctx, model_t);
|
||||
}
|
||||
ggml_set_name(base_t, base_name.c_str());
|
||||
|
||||
// allocate in backend buffer
|
||||
ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||
if (lora_buf == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// load tensor data
|
||||
auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
|
||||
read_buf.resize(ggml_nbytes(tensor));
|
||||
fin.seek(tensor_meta.offset, SEEK_SET);
|
||||
fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
|
||||
ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
|
||||
};
|
||||
load_tensor(metaA, loraA);
|
||||
load_tensor(metaB, loraB);
|
||||
|
||||
// load base model tensor data
|
||||
if (ml) {
|
||||
ml->load_data_for(base_t);
|
||||
} else {
|
||||
ggml_backend_tensor_copy(model_t, base_t);
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(base_t->type) && !warned) {
|
||||
LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
|
||||
"use a f16 or f32 base model with --lora-base\n", __func__);
|
||||
warned = true;
|
||||
}
|
||||
|
||||
if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
||||
" are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
|
||||
ggml_free(lora_ctx);
|
||||
ggml_backend_buffer_free(lora_buf);
|
||||
ggml_backend_free(backend_cpu);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto build_lora_graph = [&]() {
|
||||
// w = w + BA*s
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||
ggml_set_name(BA, "BA");
|
||||
|
||||
if (scaling != 1.0f) {
|
||||
BA = ggml_scale(lora_ctx, BA, scaling);
|
||||
ggml_set_name(BA, "BA_scaled");
|
||||
}
|
||||
|
||||
ggml_tensor * r;
|
||||
r = ggml_add_inplace(lora_ctx, base_t, BA);
|
||||
ggml_set_name(r, "r_add");
|
||||
|
||||
if (base_t->type != model_t->type) {
|
||||
// convert the result to the model type
|
||||
r = ggml_cast(lora_ctx, r, model_t->type);
|
||||
ggml_set_name(r, "r_cast");
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
||||
ggml_tensor * r = build_lora_graph();
|
||||
ggml_build_forward_expand(gf, r);
|
||||
|
||||
ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||
if (graph_buf == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
|
||||
ggml_free(lora_ctx);
|
||||
ggml_backend_buffer_free(lora_buf);
|
||||
ggml_backend_free(backend_cpu);
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_backend_graph_compute(backend_cpu, gf);
|
||||
|
||||
ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
|
||||
|
||||
#if 0
|
||||
// TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
|
||||
//ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
|
||||
|
||||
// sched compute
|
||||
ggml_build_forward_expand(gf, build_graph());
|
||||
ggml_backend_sched_init_measure(sched, gf);
|
||||
|
||||
// create the graph again, since the previous one was destroyed by the measure
|
||||
ggml_graph_clear(gf);
|
||||
ggml_build_forward_expand(gf, build_graph());
|
||||
ggml_backend_sched_graph_compute(sched, gf);
|
||||
ggml_backend_sched_free(sched);
|
||||
#endif
|
||||
|
||||
ggml_backend_buffer_free(lora_buf);
|
||||
ggml_backend_buffer_free(graph_buf);
|
||||
ggml_free(lora_ctx);
|
||||
|
||||
n_tensors++;
|
||||
if (n_tensors % 4 == 0) {
|
||||
LLAMA_LOG_INFO(".");
|
||||
// TODO: process output & token_embd tensors
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_free(backend_cpu);
|
||||
// allocate tensors / buffers and zero
|
||||
{
|
||||
adapter.ctxs.reserve(ctx_map.size());
|
||||
adapter.bufs.reserve(ctx_map.size());
|
||||
for (auto it : ctx_map) {
|
||||
ggml_backend_buffer_type_t buft = it.first;
|
||||
ggml_context * ctx = it.second;
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (!buf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate buffer for lora adapter\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
adapter.ctxs.push_back(ctx);
|
||||
adapter.bufs.push_back(buf);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||
// set tensor data
|
||||
{
|
||||
llama_file gguf_file(path_lora, "rb");
|
||||
std::vector<uint8_t> read_buf;
|
||||
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
|
||||
size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name));
|
||||
size_t size = ggml_nbytes(orig);
|
||||
if (read_buf.size() < size) {
|
||||
read_buf.resize(size);
|
||||
}
|
||||
gguf_file.read_raw(read_buf.data(), size);
|
||||
printf("%s: %s size=%ld\n", __func__, orig->name, size);
|
||||
return ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
|
||||
};
|
||||
for (auto & it : adapter.ab_map) {
|
||||
auto orig = ab_map[it.first];
|
||||
auto dev = it.second;
|
||||
set_tensor(orig.a, dev.a);
|
||||
set_tensor(orig.b, dev.b);
|
||||
}
|
||||
}
|
||||
|
||||
// free ctx for reading gguf
|
||||
ggml_free(ctx);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -19298,12 +19196,19 @@ uint32_t llama_model_quantize(
|
||||
}
|
||||
}
|
||||
|
||||
int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
|
||||
struct llama_lora_adapter * llama_lora_adapter_init(struct llama_context * ctx, const char * path_lora) {
|
||||
try {
|
||||
return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
|
||||
struct llama_lora_adapter * adapter = new llama_lora_adapter;
|
||||
int res = llama_lora_adapter_init_internal(ctx->model, path_lora, *adapter);
|
||||
if (res == 0) {
|
||||
ctx->lora_adapters.push_back(adapter);
|
||||
return adapter;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
||||
return 1;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user