update based on review comments

This commit is contained in:
ngxson 2024-07-08 11:59:01 +02:00
parent 30faf1f3de
commit 79e2982788

View File

@ -2821,20 +2821,20 @@ struct llama_context {
struct llama_control_vector cvec;
// lora adapters and scales
std::map<struct llama_lora_adapter *, float> lora_adapters;
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
};
struct lora_weight {
struct llama_lora_weight {
struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr;
lora_weight() {}
lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
llama_lora_weight() {}
llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
};
struct llama_lora_adapter {
struct llama_model * base_model;
// map tensor name to lora_a_b
std::map<std::string, struct lora_weight> ab_map;
std::unordered_map<std::string, struct llama_lora_weight> ab_map;
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@ -2842,14 +2842,13 @@ struct llama_lora_adapter {
base_model->lora_adapters.insert(this);
}
bool has_weight(struct ggml_tensor * w) {
llama_lora_weight * get_weight(struct ggml_tensor * w) {
std::string name(w->name);
return ab_map.find(name) != ab_map.end();
auto pos = ab_map.find(name);
if (ab_map.find(name) != ab_map.end()) {
return &pos->second;
}
lora_weight & get_weight(struct ggml_tensor * w) {
std::string name(w->name);
return ab_map.at(name);
return nullptr;
}
~llama_lora_adapter() {
@ -7855,23 +7854,22 @@ static void llm_build_kv_store(
}
// do mat_mul, while optionally apply lora
static struct ggml_tensor * llm_build_mm(
static struct ggml_tensor * llm_build_lora_mm(
struct llama_context & lctx,
struct ggml_context * ctx0,
struct ggml_tensor * w,
struct ggml_tensor * cur) {
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
for (auto & it : lctx.lora_adapters) {
struct llama_lora_adapter * adapter = it.first;
struct llama_lora_weight * lora = it.first->get_weight(w);
float scale = it.second;
if (!adapter->has_weight(w)) {
if (lora == nullptr) {
continue;
}
struct lora_weight & lora = adapter->get_weight(w);
// TODO: check if lora_a need transpose
struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora.a));
struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora->a));
struct ggml_tensor * ab_cur = ggml_mul_mat(
ctx0, lora.b,
ctx0, lora->b,
ggml_mul_mat(ctx0, a, cur)
);
ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
@ -7930,7 +7928,7 @@ static struct ggml_tensor * llm_build_ffn(
llm_ffn_gate_type type_gate,
const llm_build_cb & cb,
int il) {
struct ggml_tensor * tmp = up ? llm_build_mm(lctx, ctx, up, cur) : cur;
struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
cb(tmp, "ffn_up", il);
if (up_b) {
@ -7947,12 +7945,12 @@ static struct ggml_tensor * llm_build_ffn(
switch (type_gate) {
case LLM_FFN_SEQ:
{
cur = llm_build_mm(lctx, ctx, gate, tmp);
cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
cb(cur, "ffn_gate", il);
} break;
case LLM_FFN_PAR:
{
cur = llm_build_mm(lctx, ctx, gate, cur);
cur = llm_build_lora_mm(lctx, ctx, gate, cur);
cb(cur, "ffn_gate", il);
} break;
}
@ -8020,7 +8018,7 @@ static struct ggml_tensor * llm_build_ffn(
}
if (down) {
cur = llm_build_mm(lctx, ctx, down, cur);
cur = llm_build_lora_mm(lctx, ctx, down, cur);
}
if (down_b) {
@ -8058,7 +8056,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
ggml_tensor * logits = llm_build_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@ -8199,7 +8197,7 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = llm_build_mm(lctx, ctx, k, q);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
@ -8242,7 +8240,7 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
struct ggml_tensor * kqv = llm_build_mm(lctx, ctx, v, kq);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
@ -8255,7 +8253,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_build_forward_expand(graph, cur);
if (wo) {
cur = llm_build_mm(lctx, ctx, wo, cur);
cur = llm_build_lora_mm(lctx, ctx, wo, cur);
}
if (wo_b) {
@ -8762,21 +8760,21 @@ struct llm_build_context {
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_mm(lctx, ctx0, model.layers[il].wq, cur);
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
struct ggml_tensor * Kcur = llm_build_mm(lctx, ctx0, model.layers[il].wk, cur);
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
struct ggml_tensor * Vcur = llm_build_mm(lctx, ctx0, model.layers[il].wv, cur);
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@ -8864,7 +8862,7 @@ struct llm_build_context {
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_mm(lctx, ctx0, model.output, cur);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
@ -18517,7 +18515,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
}
}
static int llama_lora_adapter_init_internal(const struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
static const int n_inp_tensors = 5; // see llama_model
static const int n_out_tensors = 5; // see llama_model
LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
@ -18532,7 +18530,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
if (!ctx_gguf) {
LLAMA_LOG_ERROR("%s: failed to load lora adapter file from %s\n", __func__, path_lora);
return -1;
throw std::exception();
}
// calculate n_tensors_per_layer
@ -18574,7 +18572,7 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
}
// bundle lora_a and lora_b into pairs
std::map<std::string, lora_weight> ab_map;
std::map<std::string, llama_lora_weight> ab_map;
auto str_endswith = [](const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
};
@ -18583,18 +18581,19 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
if (str_endswith(name, ".lora_a")) {
replace_all(name, ".lora_a", "");
if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = lora_weight(cur, nullptr);
ab_map[name] = llama_lora_weight(cur, nullptr);
} else {
ab_map[name].a = cur;
}
} else if (str_endswith(name, ".lora_b")) {
replace_all(name, ".lora_b", "");
if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = lora_weight(nullptr, cur);
ab_map[name] = llama_lora_weight(nullptr, cur);
} else {
ab_map[name].b = cur;
}
} else {
// maybe "optimizer.*"" tensors
LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cur->name);
}
}
@ -18603,28 +18602,26 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
for (auto & it : ab_map) {
std::string name = it.first;
const char * cname = name.c_str();
lora_weight & w = it.second;
llama_lora_weight & w = it.second;
GGML_ASSERT(w.a != nullptr);
GGML_ASSERT(w.b != nullptr);
int il = -1;
sscanf(cname, "blk.%d.", &il);
struct ggml_context * dev_ctx; // device ctx
if (il >= 0) {
dev_ctx = ctx_map.at(model->buft_layer[il].buft);
} else if (strstr(cname, "tok") == 0) {
dev_ctx = ctx_map.at(model->buft_input.buft);
} else if (strstr(cname, "output") == 0) {
dev_ctx = ctx_map.at(model->buft_output.buft);
} else {
LLAMA_LOG_WARN("%s: discard tensor '%s'\n", __func__, cname);
continue;
// device buft and device ctx
auto model_tensor = llama_get_model_tensor(model, cname);
if (!model_tensor) {
gguf_free(ctx_gguf);
ggml_free(ctx);
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model\n");
}
struct ggml_context * dev_ctx = ctx_map.at(ggml_backend_buffer_get_type(model_tensor->buffer));
// TODO: validate tensor shape
// LLAMA_LOG_INFO("%s %p %p\n", cname, w.a, w.b);
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = lora_weight(tensor_a, tensor_b);
adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
}
// allocate tensors / buffers and zero
@ -18636,8 +18633,9 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
ggml_context * ctx_dev = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for lora adapter\n", __func__);
return -1;
gguf_free(ctx_gguf);
ggml_free(ctx);
throw std::runtime_error("failed to allocate buffer for lora adapter\n");
}
ggml_backend_buffer_clear(buf, 0);
adapter.ctxs.push_back(ctx_dev);
@ -18671,14 +18669,18 @@ static int llama_lora_adapter_init_internal(const struct llama_model * model, co
LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
// free ctx for reading gguf
gguf_free(ctx_gguf);
ggml_free(ctx);
return 0;
}
int32_t llama_lora_adapter_set(
struct llama_context * ctx,
struct llama_lora_adapter * adapter,
float scale) {
if (ctx->cparams.flash_attn) {
LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
return -1;
}
ctx->lora_adapters[adapter] = scale;
return 0;
}
@ -19479,8 +19481,8 @@ uint32_t llama_model_quantize(
struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
try {
struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
int res = llama_lora_adapter_init_internal(model, path_lora, *adapter);
return res == 0 ? adapter : nullptr;
llama_lora_adapter_init_internal(model, path_lora, *adapter);
return adapter;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
return nullptr;