update based on review comments

This commit is contained in:
ngxson 2024-07-08 11:59:01 +02:00
parent 30faf1f3de
commit 79e2982788

View File

@ -2821,20 +2821,20 @@ struct llama_context {
struct llama_control_vector cvec; struct llama_control_vector cvec;
// lora adapters and scales // lora adapters and scales
std::map<struct llama_lora_adapter *, float> lora_adapters; std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
}; };
struct lora_weight { struct llama_lora_weight {
struct ggml_tensor * a = nullptr; struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr; struct ggml_tensor * b = nullptr;
lora_weight() {} llama_lora_weight() {}
lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {} llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
}; };
struct llama_lora_adapter { struct llama_lora_adapter {
struct llama_model * base_model; struct llama_model * base_model;
// map tensor name to lora_a_b // map tensor name to lora_a_b
std::map<std::string, struct lora_weight> ab_map; std::unordered_map<std::string, struct llama_lora_weight> ab_map;
std::vector<struct ggml_context *> ctxs; std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs; std::vector<ggml_backend_buffer_t> bufs;
@ -2842,14 +2842,13 @@ struct llama_lora_adapter {
base_model->lora_adapters.insert(this); base_model->lora_adapters.insert(this);
} }
bool has_weight(struct ggml_tensor * w) { llama_lora_weight * get_weight(struct ggml_tensor * w) {
std::string name(w->name); std::string name(w->name);
return ab_map.find(name) != ab_map.end(); auto pos = ab_map.find(name);
if (ab_map.find(name) != ab_map.end()) {
return &pos->second;
} }
return nullptr;
lora_weight & get_weight(struct ggml_tensor * w) {
std::string name(w->name);
return ab_map.at(name);
} }
~llama_lora_adapter() { ~llama_lora_adapter() {
@ -7855,23 +7854,22 @@ static void llm_build_kv_store(
} }
// do mat_mul, while optionally apply lora // do mat_mul, while optionally apply lora
static struct ggml_tensor * llm_build_mm( static struct ggml_tensor * llm_build_lora_mm(
struct llama_context & lctx, struct llama_context & lctx,
struct ggml_context * ctx0, struct ggml_context * ctx0,
struct ggml_tensor * w, struct ggml_tensor * w,
struct ggml_tensor * cur) { struct ggml_tensor * cur) {
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
for (auto & it : lctx.lora_adapters) { for (auto & it : lctx.lora_adapters) {
struct llama_lora_adapter * adapter = it.first; struct llama_lora_weight * lora = it.first->get_weight(w);
float scale = it.second; float scale = it.second;
if (!adapter->has_weight(w)) { if (lora == nullptr) {
continue; continue;
} }
struct lora_weight & lora = adapter->get_weight(w);
// TODO: check if lora_a need transpose // TODO: check if lora_a need transpose
struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora.a)); struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora->a));
struct ggml_tensor * ab_cur = ggml_mul_mat( struct ggml_tensor * ab_cur = ggml_mul_mat(
ctx0, lora.b, ctx0, lora->b,
ggml_mul_mat(ctx0, a, cur) ggml_mul_mat(ctx0, a, cur)
); );
ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale); ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
@ -7930,7 +7928,7 @@ static struct ggml_tensor * llm_build_ffn(
llm_ffn_gate_type type_gate, llm_ffn_gate_type type_gate,
const llm_build_cb & cb, const llm_build_cb & cb,
int il) { int il) {
struct ggml_tensor * tmp = up ? llm_build_mm(lctx, ctx, up, cur) : cur; struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
cb(tmp, "ffn_up", il); cb(tmp, "ffn_up", il);
if (up_b) { if (up_b) {
@ -7947,12 +7945,12 @@ static struct ggml_tensor * llm_build_ffn(
switch (type_gate) { switch (type_gate) {
case LLM_FFN_SEQ: case LLM_FFN_SEQ:
{ {
cur = llm_build_mm(lctx, ctx, gate, tmp); cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
cb(cur, "ffn_gate", il); cb(cur, "ffn_gate", il);
} break; } break;
case LLM_FFN_PAR: case LLM_FFN_PAR:
{ {
cur = llm_build_mm(lctx, ctx, gate, cur); cur = llm_build_lora_mm(lctx, ctx, gate, cur);
cb(cur, "ffn_gate", il); cb(cur, "ffn_gate", il);
} break; } break;
} }
@ -8020,7 +8018,7 @@ static struct ggml_tensor * llm_build_ffn(
} }
if (down) { if (down) {
cur = llm_build_mm(lctx, ctx, down, cur); cur = llm_build_lora_mm(lctx, ctx, down, cur);
} }
if (down_b) { if (down_b) {
@ -8058,7 +8056,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
int64_t n_embd = cur->ne[0]; int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1]; int64_t n_tokens = cur->ne[1];
ggml_tensor * logits = llm_build_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il); cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@ -8199,7 +8197,7 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else { } else {
struct ggml_tensor * kq = llm_build_mm(lctx, ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
@ -8242,7 +8240,7 @@ static struct ggml_tensor * llm_build_kqv(
0); 0);
cb(v, "v", il); cb(v, "v", il);
struct ggml_tensor * kqv = llm_build_mm(lctx, ctx, v, kq); struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il); cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
@ -8255,7 +8253,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_build_forward_expand(graph, cur); ggml_build_forward_expand(graph, cur);
if (wo) { if (wo) {
cur = llm_build_mm(lctx, ctx, wo, cur); cur = llm_build_lora_mm(lctx, ctx, wo, cur);
} }
if (wo_b) { if (wo_b) {
@ -8762,21 +8760,21 @@ struct llm_build_context {
// self-attention // self-attention
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_mm(lctx, ctx0, model.layers[il].wq, cur); struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
if (model.layers[il].bq) { if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
} }
struct ggml_tensor * Kcur = llm_build_mm(lctx, ctx0, model.layers[il].wk, cur); struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
if (model.layers[il].bk) { if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
} }
struct ggml_tensor * Vcur = llm_build_mm(lctx, ctx0, model.layers[il].wv, cur); struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
if (model.layers[il].bv) { if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@ -8864,7 +8862,7 @@ struct llm_build_context {
cb(cur, "result_norm", -1); cb(cur, "result_norm", -1);
// lm_head // lm_head
cur = llm_build_mm(lctx, ctx0, model.output, cur); cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -18517,7 +18515,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
} }
} }
static int llama_lora_adapter_init_internal(const struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) { static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
static const int n_inp_tensors = 5; // see llama_model static const int n_inp_tensors = 5; // see llama_model
static const int n_out_tensors = 5; // see llama_model static const int n_out_tensors = 5; // see llama_model
LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
@ -18532,7 +18530,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params); struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
if (!ctx_gguf) { if (!ctx_gguf) {
LLAMA_LOG_ERROR("%s: failed to load lora adapter file from %s\n", __func__, path_lora); LLAMA_LOG_ERROR("%s: failed to load lora adapter file from %s\n", __func__, path_lora);
return -1; throw std::exception();
} }
// calculate n_tensors_per_layer // calculate n_tensors_per_layer
@ -18574,7 +18572,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
} }
// bundle lora_a and lora_b into pairs // bundle lora_a and lora_b into pairs
std::map<std::string, lora_weight> ab_map; std::map<std::string, llama_lora_weight> ab_map;
auto str_endswith = [](const std::string & str, const std::string & suffix) { auto str_endswith = [](const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}; };
@ -18583,18 +18581,19 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
if (str_endswith(name, ".lora_a")) { if (str_endswith(name, ".lora_a")) {
replace_all(name, ".lora_a", ""); replace_all(name, ".lora_a", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = lora_weight(cur, nullptr); ab_map[name] = llama_lora_weight(cur, nullptr);
} else { } else {
ab_map[name].a = cur; ab_map[name].a = cur;
} }
} else if (str_endswith(name, ".lora_b")) { } else if (str_endswith(name, ".lora_b")) {
replace_all(name, ".lora_b", ""); replace_all(name, ".lora_b", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = lora_weight(nullptr, cur); ab_map[name] = llama_lora_weight(nullptr, cur);
} else { } else {
ab_map[name].b = cur; ab_map[name].b = cur;
} }
} else { } else {
// maybe "optimizer.*"" tensors
LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cur->name); LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cur->name);
} }
} }
@ -18603,28 +18602,26 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
for (auto & it : ab_map) { for (auto & it : ab_map) {
std::string name = it.first; std::string name = it.first;
const char * cname = name.c_str(); const char * cname = name.c_str();
lora_weight & w = it.second; llama_lora_weight & w = it.second;
GGML_ASSERT(w.a != nullptr); GGML_ASSERT(w.a != nullptr);
GGML_ASSERT(w.b != nullptr); GGML_ASSERT(w.b != nullptr);
int il = -1; int il = -1;
sscanf(cname, "blk.%d.", &il); sscanf(cname, "blk.%d.", &il);
struct ggml_context * dev_ctx; // device ctx // device buft and device ctx
if (il >= 0) { auto model_tensor = llama_get_model_tensor(model, cname);
dev_ctx = ctx_map.at(model->buft_layer[il].buft); if (!model_tensor) {
} else if (strstr(cname, "tok") == 0) { gguf_free(ctx_gguf);
dev_ctx = ctx_map.at(model->buft_input.buft); ggml_free(ctx);
} else if (strstr(cname, "output") == 0) { throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model\n");
dev_ctx = ctx_map.at(model->buft_output.buft);
} else {
LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cname);
continue;
} }
struct ggml_context * dev_ctx = ctx_map.at(ggml_backend_buffer_get_type(model_tensor->buffer));
// TODO: validate tensor shape
// LLAMA_LOG_INFO("%s %p %p\n", cname, w.a, w.b); // LLAMA_LOG_INFO("%s %p %p\n", cname, w.a, w.b);
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name); ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = lora_weight(tensor_a, tensor_b); adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
} }
// allocate tensors / buffers and zero // allocate tensors / buffers and zero
@ -18636,8 +18633,9 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
ggml_context * ctx_dev = it.second; ggml_context * ctx_dev = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft); ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
if (!buf) { if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for lora adapter\n", __func__); gguf_free(ctx_gguf);
return -1; ggml_free(ctx);
throw std::runtime_error("failed to allocate buffer for lora adapter\n");
} }
ggml_backend_buffer_clear(buf, 0); ggml_backend_buffer_clear(buf, 0);
adapter.ctxs.push_back(ctx_dev); adapter.ctxs.push_back(ctx_dev);
@ -18671,14 +18669,18 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2); LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
// free ctx for reading gguf // free ctx for reading gguf
gguf_free(ctx_gguf);
ggml_free(ctx); ggml_free(ctx);
return 0;
} }
int32_t llama_lora_adapter_set( int32_t llama_lora_adapter_set(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_lora_adapter * adapter, struct llama_lora_adapter * adapter,
float scale) { float scale) {
if (ctx->cparams.flash_attn) {
LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
return -1;
}
ctx->lora_adapters[adapter] = scale; ctx->lora_adapters[adapter] = scale;
return 0; return 0;
} }
@ -19479,8 +19481,8 @@ uint32_t llama_model_quantize(
struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
try { try {
struct llama_lora_adapter * adapter = new llama_lora_adapter(model); struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
int res = llama_lora_adapter_init_internal(model, path_lora, *adapter); llama_lora_adapter_init_internal(model, path_lora, *adapter);
return res == 0 ? adapter : nullptr; return adapter;
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
return nullptr; return nullptr;