From 79e2982788b0102aabb098b1a3d6227a7e32a483 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Mon, 8 Jul 2024 11:59:01 +0200
Subject: [PATCH] update based on review comments

---
 src/llama.cpp | 106 +++++++++++++++++++++++++-------------------------
 1 file changed, 54 insertions(+), 52 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index ffc8ffbd2..a4ceb0959 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2821,20 +2821,20 @@ struct llama_context {
     struct llama_control_vector cvec;
 
     // lora adapters and scales
-    std::map<struct llama_lora_adapter *, float> lora_adapters;
+    std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 };
 
-struct lora_weight {
+struct llama_lora_weight {
     struct ggml_tensor * a = nullptr;
     struct ggml_tensor * b = nullptr;
-    lora_weight() {}
-    lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
+    llama_lora_weight() {}
+    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
 };
 
 struct llama_lora_adapter {
     struct llama_model * base_model;
     // map tensor name to lora_a_b
-    std::map<std::string, struct lora_weight> ab_map;
+    std::unordered_map<std::string, struct llama_lora_weight> ab_map;
     std::vector<struct ggml_context *> ctxs;
     std::vector<ggml_backend_buffer_t> bufs;
 
@@ -2842,14 +2842,13 @@ struct llama_lora_adapter {
         base_model->lora_adapters.insert(this);
     }
 
-    bool has_weight(struct ggml_tensor * w) {
+    llama_lora_weight * get_weight(struct ggml_tensor * w) {
         std::string name(w->name);
-        return ab_map.find(name) != ab_map.end();
-    }
-
-    lora_weight & get_weight(struct ggml_tensor * w) {
-        std::string name(w->name);
-        return ab_map.at(name);
+        auto pos = ab_map.find(name);
+        if (pos != ab_map.end()) {
+            return &pos->second;
+        }
+        return nullptr;
     }
 
     ~llama_lora_adapter() {
@@ -7855,23 +7854,22 @@ static void llm_build_kv_store(
 }
 
 // do mat_mul, while optionally apply lora
-static struct ggml_tensor * llm_build_mm(
+static struct ggml_tensor * llm_build_lora_mm(
        struct llama_context & lctx,
         struct ggml_context * ctx0,
          struct ggml_tensor * w,
          struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
     for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_adapter * adapter = it.first;
+        struct llama_lora_weight * lora = it.first->get_weight(w);
         float scale = it.second;
-        if (!adapter->has_weight(w)) {
+        if (lora == nullptr) {
             continue;
         }
-        struct lora_weight & lora = adapter->get_weight(w);
         // TODO: check if lora_a need transpose
-        struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora.a));
+        struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora->a));
         struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lora.b,
+            ctx0, lora->b,
             ggml_mul_mat(ctx0, a, cur)
         );
         ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
@@ -7930,7 +7928,7 @@ static struct ggml_tensor * llm_build_ffn(
         llm_ffn_gate_type   type_gate,
         const llm_build_cb & cb,
         int   il) {
-    struct ggml_tensor * tmp = up ? llm_build_mm(lctx, ctx, up, cur) : cur;
+    struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -7947,12 +7945,12 @@ static struct ggml_tensor * llm_build_ffn(
     switch (type_gate) {
         case LLM_FFN_SEQ:
             {
-                cur = llm_build_mm(lctx, ctx, gate, tmp);
+                cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
                 cb(cur, "ffn_gate", il);
             } break;
         case LLM_FFN_PAR:
             {
-                cur = llm_build_mm(lctx, ctx, gate, cur);
+                cur = llm_build_lora_mm(lctx, ctx, gate, cur);
                 cb(cur, "ffn_gate", il);
             } break;
     }
@@ -8020,7 +8018,7 @@ static struct ggml_tensor * llm_build_ffn(
     }
 
     if (down) {
-        cur = llm_build_mm(lctx, ctx, down, cur);
+        cur = llm_build_lora_mm(lctx, ctx, down, cur);
     }
 
     if (down_b) {
@@ -8058,7 +8056,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
 
-    ggml_tensor * logits = llm_build_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
+    ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
     ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@@ -8199,7 +8197,7 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
-        struct ggml_tensor * kq = llm_build_mm(lctx, ctx, k, q);
+        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
@@ -8242,7 +8240,7 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        struct ggml_tensor * kqv = llm_build_mm(lctx, ctx, v, kq);
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
         cb(kqv, "kqv", il);
 
         struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
@@ -8255,7 +8253,7 @@ static struct ggml_tensor * llm_build_kqv(
     ggml_build_forward_expand(graph, cur);
 
     if (wo) {
-        cur = llm_build_mm(lctx, ctx, wo, cur);
+        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
     }
 
     if (wo_b) {
@@ -8762,21 +8760,21 @@ struct llm_build_context {
         // self-attention
         {
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = llm_build_mm(lctx, ctx0, model.layers[il].wq, cur);
+            struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
             cb(Qcur, "Qcur", il);
             if (model.layers[il].bq) {
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
             }
 
-            struct ggml_tensor * Kcur = llm_build_mm(lctx, ctx0, model.layers[il].wk, cur);
+            struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
             cb(Kcur, "Kcur", il);
             if (model.layers[il].bk) {
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
             }
 
-            struct ggml_tensor * Vcur = llm_build_mm(lctx, ctx0, model.layers[il].wv, cur);
+            struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
             cb(Vcur, "Vcur", il);
             if (model.layers[il].bv) {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -8864,7 +8862,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = llm_build_mm(lctx, ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -18517,7 +18515,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
-static int llama_lora_adapter_init_internal(const struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
+static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
     static const int n_inp_tensors = 5; // see llama_model
     static const int n_out_tensors = 5; // see llama_model
     LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
@@ -18532,7 +18530,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
     struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
     if (!ctx_gguf) {
         LLAMA_LOG_ERROR("%s: failed to load lora adapter file from %s\n", __func__, path_lora);
-        return -1;
+        throw std::exception();
     }
 
     // calculate n_tensors_per_layer
@@ -18574,7 +18572,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
     }
 
     // bundle lora_a and lora_b into pairs
-    std::map<std::string, lora_weight> ab_map;
+    std::map<std::string, llama_lora_weight> ab_map;
     auto str_endswith = [](const std::string & str, const std::string & suffix) {
         return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
     };
@@ -18583,18 +18581,19 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
         if (str_endswith(name, ".lora_a")) {
             replace_all(name, ".lora_a", "");
             if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = lora_weight(cur, nullptr);
+                ab_map[name] = llama_lora_weight(cur, nullptr);
             } else {
                 ab_map[name].a = cur;
             }
         } else if (str_endswith(name, ".lora_b")) {
             replace_all(name, ".lora_b", "");
             if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = lora_weight(nullptr, cur);
+                ab_map[name] = llama_lora_weight(nullptr, cur);
             } else {
                 ab_map[name].b = cur;
             }
         } else {
+            // maybe "optimizer.*" tensors
             LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cur->name);
         }
     }
@@ -18603,28 +18602,26 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
     for (auto & it : ab_map) {
         std::string name = it.first;
         const char * cname = name.c_str();
-        lora_weight & w = it.second;
+        llama_lora_weight & w = it.second;
         GGML_ASSERT(w.a != nullptr);
         GGML_ASSERT(w.b != nullptr);
         int il = -1;
         sscanf(cname, "blk.%d.", &il);
-        struct ggml_context * dev_ctx; // device ctx
-        if (il >= 0) {
-            dev_ctx = ctx_map.at(model->buft_layer[il].buft);
-        } else if (strstr(cname, "tok") == 0) {
-            dev_ctx = ctx_map.at(model->buft_input.buft);
-        } else if (strstr(cname, "output") == 0) {
-            dev_ctx = ctx_map.at(model->buft_output.buft);
-        } else {
-            LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cname);
-            continue;
+        // device buft and device ctx
+        auto model_tensor = llama_get_model_tensor(model, cname);
+        if (!model_tensor) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model\n");
         }
+        struct ggml_context * dev_ctx = ctx_map.at(ggml_backend_buffer_get_type(model_tensor->buffer));
+        // TODO: validate tensor shape
         // LLAMA_LOG_INFO("%s %p %p\n", cname, w.a, w.b);
         struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
         struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
-        adapter.ab_map[name] = lora_weight(tensor_a, tensor_b);
+        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
     }
 
     // allocate tensors / buffers and zero
@@ -18636,8 +18633,9 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
         ggml_context * ctx_dev = it.second;
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for lora adapter\n", __func__);
-            return -1;
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("failed to allocate buffer for lora adapter\n");
         }
         ggml_backend_buffer_clear(buf, 0);
         adapter.ctxs.push_back(ctx_dev);
@@ -18671,14 +18669,18 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
     LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 
     // free ctx for reading gguf
+    gguf_free(ctx_gguf);
     ggml_free(ctx);
-    return 0;
 }
 
 int32_t llama_lora_adapter_set(
         struct llama_context * ctx,
         struct llama_lora_adapter * adapter,
         float scale) {
+    if (ctx->cparams.flash_attn) {
+        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
+        return -1;
+    }
     ctx->lora_adapters[adapter] = scale;
     return 0;
 }
@@ -19479,8 +19481,8 @@ uint32_t llama_model_quantize(
 struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
     try {
         struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
-        int res = llama_lora_adapter_init_internal(model, path_lora, *adapter);
-        return res == 0 ? adapter : nullptr;
+        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        return adapter;
     } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
        return nullptr;
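
Note on the graph construction: llm_build_lora_mm above computes
res = W*cur + sum over active adapters of scale * (B * (A^T * cur)).
Below is a minimal standalone sketch of that computation for a single
adapter, with the per-adapter lookup inlined; the function name
build_lora_mm_sketch and the tensor names are illustrative only, not
part of the patch.

#include "ggml.h"

// Sketch: y = W*x + scale * (B * (A^T * x)), mirroring the loop body of
// llm_build_lora_mm. Assumes ctx0 is a valid ggml_context and that w,
// lora_a, lora_b and x were created in it with compatible shapes.
static struct ggml_tensor * build_lora_mm_sketch(
        struct ggml_context * ctx0,
        struct ggml_tensor  * w,       // base weight
        struct ggml_tensor  * lora_a,  // LoRA A factor (rank r)
        struct ggml_tensor  * lora_b,  // LoRA B factor (rank r)
        struct ggml_tensor  * x,       // input activations
        float                 scale) {
    // base path: W * x
    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, x);
    // low-rank path: B * (A^T * x); the transpose + cont matches the
    // "TODO: check if lora_a need transpose" site in the patch
    struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora_a));
    struct ggml_tensor * ab_cur = ggml_mul_mat(ctx0, lora_b,
                                               ggml_mul_mat(ctx0, a, x));
    ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
    // merged result: base output plus the scaled low-rank correction
    return ggml_add(ctx0, res, ab_cur);
}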
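
The error-handling change also reshapes the caller-facing flow: the internal
init now throws instead of returning an int, llama_lora_adapter_init catches
and returns nullptr, and llama_lora_adapter_set rejects flash_attn contexts.
A hedged caller sketch against the llama.h API as of this patch (the adapter
path and scale value are placeholders):

// assumes an initialized llama_model * model and llama_context * lctx
struct llama_lora_adapter * adapter =
        llama_lora_adapter_init(model, "/path/to/adapter.gguf"); // placeholder path
if (adapter == nullptr) {
    // init already logged "failed to apply lora adapter: ..."
    return 1;
}
if (llama_lora_adapter_set(lctx, adapter, 1.0f) != 0) {
    // returns -1 when the context was created with flash_attn enabled
    return 1;
}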