mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
lora : add support for non-llama models (#3333)
* lora : add support for non-llama models ggml-ci * avoid leaking ggml_context on failure cleanup ggml-ci * lora : allow 1d tensors * lora : include embd and output layers in size calculation * fix style
This commit is contained in:
parent
8a5be3bd58
commit
c6c4fc081c
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, BinaryIO, Sequence
|
from typing import Any, BinaryIO, Sequence
|
||||||
@ -11,43 +10,15 @@ from typing import Any, BinaryIO, Sequence
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
|
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||||
|
import gguf
|
||||||
|
|
||||||
|
|
||||||
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
||||||
|
|
||||||
|
|
||||||
HF_SUBLAYER_TO_GGML = {
|
|
||||||
"self_attn.q_proj": "attn_q",
|
|
||||||
"self_attn.k_proj": "attn_k",
|
|
||||||
"self_attn.v_proj": "attn_v",
|
|
||||||
"self_attn.o_proj": "attn_output",
|
|
||||||
"mlp.gate_proj": "ffn_gate",
|
|
||||||
"mlp.down_proj": "ffn_down",
|
|
||||||
"mlp.up_proj": "ffn_up",
|
|
||||||
"input_layernorm": "attn_norm",
|
|
||||||
"post_attention_layernorm": "ffn_norm",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def translate_tensor_name(t: str) -> str:
|
|
||||||
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
|
|
||||||
if match:
|
|
||||||
nn = match.group(1)
|
|
||||||
sub_layer = match.group(2)
|
|
||||||
lora_type = match.group(3)
|
|
||||||
|
|
||||||
sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
|
|
||||||
if sub_layer_renamed is None:
|
|
||||||
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
output_string = (
|
|
||||||
f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
|
|
||||||
)
|
|
||||||
return output_string
|
|
||||||
else:
|
|
||||||
print(f"Error: unrecognized tensor {t}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
||||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||||
fout.write(struct.pack("i", 1)) # file version
|
fout.write(struct.pack("i", 1)) # file version
|
||||||
@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
|||||||
fout.write(struct.pack("i", int(params["lora_alpha"])))
|
fout.write(struct.pack("i", int(params["lora_alpha"])))
|
||||||
|
|
||||||
|
|
||||||
def write_tensor_header(
|
def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
|
||||||
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
|
|
||||||
) -> None:
|
|
||||||
sname = name.encode("utf-8")
|
sname = name.encode("utf-8")
|
||||||
fout.write(
|
fout.write(
|
||||||
struct.pack(
|
struct.pack(
|
||||||
@ -78,11 +47,12 @@ def write_tensor_header(
|
|||||||
fout.seek((fout.tell() + 31) & -32)
|
fout.seek((fout.tell() + 31) & -32)
|
||||||
|
|
||||||
|
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) < 2:
|
||||||
print(f"Usage: python {sys.argv[0]} <path>")
|
print(f"Usage: python {sys.argv[0]} <path> [arch]")
|
||||||
print(
|
print(
|
||||||
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
||||||
)
|
)
|
||||||
|
print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
||||||
@ -90,6 +60,14 @@ input_model = os.path.join(sys.argv[1], "adapter_model.bin")
|
|||||||
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
||||||
|
|
||||||
model = torch.load(input_model, map_location="cpu")
|
model = torch.load(input_model, map_location="cpu")
|
||||||
|
arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
|
||||||
|
|
||||||
|
if arch_name not in gguf.MODEL_ARCH_NAMES.values():
|
||||||
|
print(f"Error: unsupported architecture {arch_name}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
|
||||||
|
name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone
|
||||||
|
|
||||||
with open(input_json, "r") as f:
|
with open(input_json, "r") as f:
|
||||||
params = json.load(f)
|
params = json.load(f)
|
||||||
@ -117,6 +95,7 @@ with open(output_path, "wb") as fout:
|
|||||||
|
|
||||||
write_file_header(fout, params)
|
write_file_header(fout, params)
|
||||||
for k, v in model.items():
|
for k, v in model.items():
|
||||||
|
orig_k = k
|
||||||
if k.endswith(".default.weight"):
|
if k.endswith(".default.weight"):
|
||||||
k = k.replace(".default.weight", ".weight")
|
k = k.replace(".default.weight", ".weight")
|
||||||
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
||||||
@ -129,7 +108,32 @@ with open(output_path, "wb") as fout:
|
|||||||
v = v.float()
|
v = v.float()
|
||||||
|
|
||||||
t = v.detach().numpy()
|
t = v.detach().numpy()
|
||||||
tname = translate_tensor_name(k)
|
|
||||||
|
prefix = "base_model.model."
|
||||||
|
if k.startswith(prefix):
|
||||||
|
k = k[len(prefix) :]
|
||||||
|
|
||||||
|
lora_suffixes = (".lora_A.weight", ".lora_B.weight")
|
||||||
|
if k.endswith(lora_suffixes):
|
||||||
|
suffix = k[-len(lora_suffixes[0]):]
|
||||||
|
k = k[: -len(lora_suffixes[0])]
|
||||||
|
else:
|
||||||
|
print(f"Error: unrecognized tensor name {orig_k}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
tname = name_map.get_name(k)
|
||||||
|
if tname is None:
|
||||||
|
print(f"Error: could not map tensor name {orig_k}")
|
||||||
|
print(" Note: the arch parameter must be specified if the model is not llama")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if suffix == ".lora_A.weight":
|
||||||
|
tname += ".weight.loraA"
|
||||||
|
elif suffix == ".lora_B.weight":
|
||||||
|
tname += ".weight.loraB"
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||||
t.tofile(fout)
|
t.tofile(fout)
|
||||||
|
133
llama.cpp
133
llama.cpp
@ -8647,53 +8647,60 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
|
|
||||||
const int64_t t_start_lora_us = ggml_time_us();
|
const int64_t t_start_lora_us = ggml_time_us();
|
||||||
|
|
||||||
auto fin = std::ifstream(path_lora, std::ios::binary);
|
llama_file fin(path_lora, "rb");
|
||||||
if (!fin) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_lora);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify magic and version
|
// verify magic and version
|
||||||
{
|
{
|
||||||
uint32_t magic;
|
uint32_t magic = fin.read_u32();
|
||||||
fin.read((char *) &magic, sizeof(magic));
|
if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||||
uint32_t format_version;
|
LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
|
||||||
fin.read((char *) &format_version, sizeof(format_version));
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t format_version = fin.read_u32();
|
||||||
if (format_version != 1) {
|
if (format_version != 1) {
|
||||||
LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t lora_r;
|
int32_t lora_r = fin.read_u32();
|
||||||
int32_t lora_alpha;
|
int32_t lora_alpha = fin.read_u32();
|
||||||
fin.read((char *) &lora_r, sizeof(lora_r));
|
|
||||||
fin.read((char *) &lora_alpha, sizeof(lora_alpha));
|
|
||||||
float scaling = scale * (float)lora_alpha / (float)lora_r;
|
float scaling = scale * (float)lora_alpha / (float)lora_r;
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||||
|
|
||||||
|
// create a name -> tensor map of the model to accelerate lookups
|
||||||
|
// find the max tensor size to estimate the required temporary buffer size
|
||||||
|
size_t max_tensor_size = 0;
|
||||||
|
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
||||||
|
for (const auto & kv : model.tensors_by_name) {
|
||||||
|
model_tensors.insert(kv);
|
||||||
|
size_t f32_size = ggml_nelements(kv.second) * sizeof(float);
|
||||||
|
max_tensor_size = std::max(max_tensor_size, f32_size);
|
||||||
|
}
|
||||||
|
|
||||||
// create a temporary ggml context to store the lora tensors
|
// create a temporary ggml context to store the lora tensors
|
||||||
// todo: calculate size from biggest possible tensor
|
// TODO: use ggml-alloc
|
||||||
std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
|
size_t lora_ctx_size = max_tensor_size * 3;
|
||||||
|
LLAMA_LOG_INFO("%s: allocating %.f MB for lora temporary buffer\n", __func__, lora_ctx_size / 1024.0 / 1024.0);
|
||||||
|
std::vector<uint8_t> lora_buf(lora_ctx_size);
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = lora_buf.size();
|
params.mem_size = lora_buf.size();
|
||||||
params.mem_buffer = lora_buf.data();
|
params.mem_buffer = lora_buf.data();
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
|
|
||||||
ggml_context * lora_ctx = ggml_init(params);
|
using unique_context = std::unique_ptr<ggml_context, decltype(&ggml_free)>;
|
||||||
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
|
||||||
|
|
||||||
// create a name -> tensor map of the model to accelerate lookups
|
unique_context lora_ctx(nullptr, ggml_free);
|
||||||
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
lora_ctx.reset(ggml_init(params));
|
||||||
for (const auto & kv : model.tensors_by_name) {
|
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
||||||
model_tensors.insert(kv);
|
|
||||||
}
|
|
||||||
|
|
||||||
// load base model
|
// load base model
|
||||||
std::unique_ptr<llama_model_loader> ml;
|
std::unique_ptr<llama_model_loader> ml;
|
||||||
ggml_context * base_ctx = NULL;
|
|
||||||
|
unique_context base_ctx(nullptr, ggml_free);
|
||||||
std::vector<uint8_t> base_buf;
|
std::vector<uint8_t> base_buf;
|
||||||
if (path_base_model) {
|
if (path_base_model) {
|
||||||
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||||
@ -8702,6 +8709,7 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
size_t ctx_size;
|
size_t ctx_size;
|
||||||
size_t mmapped_size;
|
size_t mmapped_size;
|
||||||
ml->calc_sizes(ctx_size, mmapped_size);
|
ml->calc_sizes(ctx_size, mmapped_size);
|
||||||
|
|
||||||
base_buf.resize(ctx_size);
|
base_buf.resize(ctx_size);
|
||||||
|
|
||||||
ggml_init_params base_params;
|
ggml_init_params base_params;
|
||||||
@ -8709,9 +8717,9 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
base_params.mem_buffer = base_buf.data();
|
base_params.mem_buffer = base_buf.data();
|
||||||
base_params.no_alloc = ml->use_mmap;
|
base_params.no_alloc = ml->use_mmap;
|
||||||
|
|
||||||
base_ctx = ggml_init(base_params);
|
base_ctx.reset(ggml_init(base_params));
|
||||||
|
|
||||||
// maybe this should in llama_model_loader
|
// maybe this should be in llama_model_loader
|
||||||
if (ml->use_mmap) {
|
if (ml->use_mmap) {
|
||||||
ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
|
ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
|
||||||
}
|
}
|
||||||
@ -8724,27 +8732,35 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
std::vector<uint8_t> work_buffer;
|
std::vector<uint8_t> work_buffer;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (fin.tell() == fin.size) {
|
||||||
|
// eof
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t n_dims;
|
int32_t n_dims;
|
||||||
int32_t length;
|
int32_t name_len;
|
||||||
int32_t ftype;
|
int32_t ftype;
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
fin.read_raw(&n_dims, sizeof(n_dims));
|
||||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
fin.read_raw(&name_len, sizeof(name_len));
|
||||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
fin.read_raw(&ftype, sizeof(ftype));
|
||||||
if (fin.eof()) {
|
|
||||||
break;
|
if (n_dims != 1 && n_dims != 2) {
|
||||||
|
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t ne[2] = { 1, 1 };
|
int32_t ne[2] = { 1, 1 };
|
||||||
for (int i = 0; i < n_dims; ++i) {
|
for (int i = 0; i < n_dims; ++i) {
|
||||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
fin.read_raw(&ne[i], sizeof(ne[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(name_len <= 1024);
|
||||||
char buf[1024];
|
char buf[1024];
|
||||||
fin.read(buf, length);
|
fin.read_raw(buf, name_len);
|
||||||
name = std::string(buf, length);
|
name = std::string(buf, name_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for lora suffix and get the type of tensor
|
// check for lora suffix and get the type of tensor
|
||||||
@ -8758,7 +8774,7 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
std::string lora_type = name.substr(pos + lora_suffix.length());
|
std::string lora_type = name.substr(pos + lora_suffix.length());
|
||||||
std::string base_name = name;
|
std::string base_name = name;
|
||||||
base_name.erase(pos);
|
base_name.erase(pos);
|
||||||
// LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
|
// LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());
|
||||||
|
|
||||||
if (model_tensors.find(base_name) == model_tensors.end()) {
|
if (model_tensors.find(base_name) == model_tensors.end()) {
|
||||||
LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
|
LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
|
||||||
@ -8777,22 +8793,15 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_tensor * lora_tensor;
|
ggml_tensor * lora_tensor = ggml_new_tensor_2d(lora_ctx.get(), wtype, ne[0], ne[1]);
|
||||||
if (n_dims == 2) {
|
ggml_set_name(lora_tensor, name.c_str());
|
||||||
lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
ggml_set_name(lora_tensor, "lora_tensor");
|
|
||||||
|
|
||||||
// load tensor data
|
// load tensor data
|
||||||
size_t offset = fin.tellg();
|
size_t offset = fin.tell();
|
||||||
size_t tensor_data_size = ggml_nbytes(lora_tensor);
|
size_t tensor_data_size = ggml_nbytes(lora_tensor);
|
||||||
offset = (offset + 31) & -32;
|
offset = (offset + 31) & -32;
|
||||||
fin.seekg(offset);
|
fin.seek(offset, SEEK_SET);
|
||||||
fin.read((char*)lora_tensor->data, tensor_data_size);
|
fin.read_raw(lora_tensor->data, tensor_data_size);
|
||||||
|
|
||||||
lora_tensors[name] = lora_tensor;
|
lora_tensors[name] = lora_tensor;
|
||||||
|
|
||||||
@ -8822,13 +8831,11 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
|
|
||||||
// load from base model
|
// load from base model
|
||||||
if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
|
if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
|
||||||
// TODO: throw
|
|
||||||
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: not tested!! maybe not working!
|
base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, GGML_BACKEND_CPU);
|
||||||
base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
|
|
||||||
ml->load_data_for(base_t);
|
ml->load_data_for(base_t);
|
||||||
} else {
|
} else {
|
||||||
base_t = dest_t;
|
base_t = dest_t;
|
||||||
@ -8857,43 +8864,45 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// w = w + BA*s
|
// w = w + BA*s
|
||||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
ggml_tensor * BA = ggml_mul_mat(lora_ctx.get(), loraA, loraB);
|
||||||
offload_func(BA);
|
offload_func(BA);
|
||||||
ggml_set_name(BA, "BA");
|
ggml_set_name(BA, "BA");
|
||||||
|
|
||||||
if (scaling != 1.0f) {
|
if (scaling != 1.0f) {
|
||||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling);
|
||||||
ggml_set_name(scale_tensor, "scale_tensor");
|
ggml_set_name(scale_tensor, "scale_tensor");
|
||||||
|
|
||||||
BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
|
BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor);
|
||||||
offload_func(BA);
|
offload_func(BA);
|
||||||
ggml_set_name(BA, "BA_scaled");
|
ggml_set_name(BA, "BA_scaled");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * r;
|
ggml_tensor * r;
|
||||||
if (base_t == dest_t) {
|
if (base_t == dest_t) {
|
||||||
r = ggml_add_inplace(lora_ctx, dest_t, BA);
|
r = ggml_add_inplace(lora_ctx.get(), dest_t, BA);
|
||||||
offload_func_force_inplace(r);
|
offload_func_force_inplace(r);
|
||||||
ggml_set_name(r, "r_add_inplace");
|
ggml_set_name(r, "r_add_inplace");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
r = ggml_add(lora_ctx, base_t, BA);
|
r = ggml_add(lora_ctx.get(), base_t, BA);
|
||||||
offload_func(r);
|
offload_func(r);
|
||||||
ggml_set_name(r, "r_add");
|
ggml_set_name(r, "r_add");
|
||||||
|
|
||||||
r = ggml_cpy(lora_ctx, r, dest_t);
|
r = ggml_cpy(lora_ctx.get(), r, dest_t);
|
||||||
offload_func(r);
|
offload_func(r);
|
||||||
ggml_set_name(r, "r_cpy");
|
ggml_set_name(r, "r_cpy");
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
struct ggml_cgraph * gf = ggml_new_graph(lora_ctx.get());
|
||||||
ggml_build_forward_expand(gf, r);
|
ggml_build_forward_expand(gf, r);
|
||||||
|
|
||||||
ggml_graph_compute_helper(work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(work_buffer, gf, n_threads);
|
||||||
|
|
||||||
|
// the tensors in the adapter must be sorted such that loraA and loraB of the same tensor are next to each other
|
||||||
|
GGML_ASSERT(lora_tensors.size() == 2);
|
||||||
|
|
||||||
// we won't need these tensors again, reset the context to save memory
|
// we won't need these tensors again, reset the context to save memory
|
||||||
ggml_free(lora_ctx);
|
lora_ctx.reset(ggml_init(params));
|
||||||
lora_ctx = ggml_init(params);
|
|
||||||
lora_tensors.clear();
|
lora_tensors.clear();
|
||||||
|
|
||||||
n_tensors++;
|
n_tensors++;
|
||||||
@ -8903,12 +8912,6 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this should be in a destructor, it will leak on failure
|
|
||||||
ggml_free(lora_ctx);
|
|
||||||
if (base_ctx) {
|
|
||||||
ggml_free(base_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||||
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||||
|
|
||||||
|
1
llama.h
1
llama.h
@ -39,6 +39,7 @@
|
|||||||
|
|
||||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||||
|
|
||||||
|
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
|
Loading…
Reference in New Issue
Block a user