mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
convert_hf : simplify modify_tensors for InternLM2
* convert_lora : lazy conversion * llama : load and use alpha from LoRA adapters
This commit is contained in:
parent
9d96328bdf
commit
8956543c09
@ -2222,13 +2222,6 @@ class InternLM2Model(Model):
|
|||||||
|
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
|
|
||||||
if n_head_kv is not None and n_head != n_head_kv:
|
|
||||||
n_head = n_head_kv
|
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
|
||||||
.swapaxes(1, 2)
|
|
||||||
.reshape(weights.shape))
|
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_name("InternLM2")
|
self.gguf_writer.add_name("InternLM2")
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
@ -2248,26 +2241,22 @@ class InternLM2Model(Model):
|
|||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
num_heads = self.hparams["num_attention_heads"]
|
num_heads = self.hparams["num_attention_heads"]
|
||||||
num_kv_heads = self.hparams["num_key_value_heads"]
|
num_kv_heads = self.hparams["num_key_value_heads"]
|
||||||
hidden_size = self.hparams["hidden_size"]
|
n_embd = self.hparams["hidden_size"]
|
||||||
q_per_kv = num_heads // num_kv_heads
|
q_per_kv = num_heads // num_kv_heads
|
||||||
head_dim = hidden_size // num_heads
|
head_dim = n_embd // num_heads
|
||||||
num_groups = num_heads // q_per_kv
|
num_groups = num_heads // q_per_kv
|
||||||
|
|
||||||
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
|
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
|
||||||
|
|
||||||
if re.match(qkv_pattern, name):
|
|
||||||
bid = re.findall(qkv_pattern, name)[0]
|
|
||||||
qkv = data_torch
|
qkv = data_torch
|
||||||
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
|
|
||||||
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
|
qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
|
||||||
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
|
q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
|
||||||
|
|
||||||
# The model weights of q and k equire additional reshape.
|
# The model weights of q and k equire additional reshape.
|
||||||
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
|
q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
|
||||||
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
|
k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
|
||||||
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
|
v = v.reshape((-1, v.shape[-1]))
|
||||||
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
|
|
||||||
# v = rearrange(v, " o g n i -> o (g n i)").T
|
|
||||||
v = v.reshape((v.shape[0], -1)).T
|
|
||||||
return [
|
return [
|
||||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
|
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
|
||||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
|
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
|
||||||
|
@ -8,9 +8,10 @@ import logging
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import json
|
||||||
|
from math import prod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import EllipsisType
|
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
|
||||||
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -22,7 +23,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
|||||||
import gguf
|
import gguf
|
||||||
|
|
||||||
# reuse model definitions from convert_hf_to_gguf.py
|
# reuse model definitions from convert_hf_to_gguf.py
|
||||||
from convert_hf_to_gguf import Model
|
from convert_hf_to_gguf import LazyTorchTensor, Model
|
||||||
|
|
||||||
logger = logging.getLogger("lora-to-gguf")
|
logger = logging.getLogger("lora-to-gguf")
|
||||||
|
|
||||||
@ -35,37 +36,45 @@ class PartialLoraTensor:
|
|||||||
|
|
||||||
# magic to support tensor shape modifications and splitting
|
# magic to support tensor shape modifications and splitting
|
||||||
class LoraTorchTensor:
|
class LoraTorchTensor:
|
||||||
_lora_A: Tensor
|
_lora_A: Tensor # (n_rank, row_size)
|
||||||
_lora_B: Tensor
|
_lora_B: Tensor # (col_size, n_rank)
|
||||||
_rank: int
|
_rank: int
|
||||||
|
|
||||||
def __init__(self, A: Tensor, B: Tensor):
|
def __init__(self, A: Tensor, B: Tensor):
|
||||||
assert len(A.shape) == len(B.shape)
|
assert len(A.shape) == len(B.shape)
|
||||||
|
assert A.shape[-2] == B.shape[-1]
|
||||||
if A.dtype != B.dtype:
|
if A.dtype != B.dtype:
|
||||||
A = A.to(torch.float32)
|
A = A.to(torch.float32)
|
||||||
B = B.to(torch.float32)
|
B = B.to(torch.float32)
|
||||||
self._lora_A = A
|
self._lora_A = A
|
||||||
self._lora_B = B
|
self._lora_B = B
|
||||||
assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
|
self._rank = B.shape[-1]
|
||||||
self._rank = self._lora_B.shape[-1]
|
|
||||||
|
def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
|
||||||
|
return (self._lora_A, self._lora_B)
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self,
|
self,
|
||||||
indices: (
|
indices: (
|
||||||
SupportsIndex
|
SupportsIndex
|
||||||
| slice
|
| slice
|
||||||
| tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
|
| tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
|
||||||
),
|
),
|
||||||
) -> LoraTorchTensor:
|
) -> LoraTorchTensor:
|
||||||
shape = self.shape
|
shape = self.shape
|
||||||
if isinstance(indices, (SupportsIndex, slice)):
|
if isinstance(indices, SupportsIndex):
|
||||||
if len(shape) > 2:
|
if len(shape) > 2:
|
||||||
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
|
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError # can't return a vector
|
||||||
|
elif isinstance(indices, slice):
|
||||||
|
if len(shape) > 2:
|
||||||
|
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
|
||||||
|
else:
|
||||||
|
return LoraTorchTensor(self._lora_A, self._lora_B[indices])
|
||||||
elif isinstance(indices, tuple):
|
elif isinstance(indices, tuple):
|
||||||
assert len(indices) > 0
|
assert len(indices) > 0
|
||||||
if isinstance(indices[-1], EllipsisType):
|
if indices[-1] is Ellipsis:
|
||||||
return self[indices[:-1]]
|
return self[indices[:-1]]
|
||||||
# expand ellipsis
|
# expand ellipsis
|
||||||
indices = tuple(
|
indices = tuple(
|
||||||
@ -73,7 +82,7 @@ class LoraTorchTensor:
|
|||||||
for v in (
|
for v in (
|
||||||
(
|
(
|
||||||
(slice(None, None) for _ in range(len(indices) - 1))
|
(slice(None, None) for _ in range(len(indices) - 1))
|
||||||
if isinstance(i, EllipsisType)
|
if i is Ellipsis
|
||||||
else (i,)
|
else (i,)
|
||||||
)
|
)
|
||||||
for i in indices
|
for i in indices
|
||||||
@ -85,11 +94,14 @@ class LoraTorchTensor:
|
|||||||
indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
|
indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
|
||||||
|
|
||||||
# TODO: make sure this is correct
|
# TODO: make sure this is correct
|
||||||
# lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
|
|
||||||
indices_A = (
|
indices_A = (
|
||||||
*(
|
*(
|
||||||
0 if isinstance(i, SupportsIndex) else slice(None, None)
|
(
|
||||||
for i in indices[:-2]
|
j.__index__() % self._lora_A.shape[i]
|
||||||
|
if isinstance(j, SupportsIndex)
|
||||||
|
else slice(None, None)
|
||||||
|
)
|
||||||
|
for i, j in enumerate(indices[:-2])
|
||||||
),
|
),
|
||||||
slice(None, None),
|
slice(None, None),
|
||||||
indices[-1],
|
indices[-1],
|
||||||
@ -97,7 +109,7 @@ class LoraTorchTensor:
|
|||||||
indices_B = indices[:-1]
|
indices_B = indices[:-1]
|
||||||
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
|
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError # unknown indice type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
@ -106,23 +118,37 @@ class LoraTorchTensor:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
|
assert len(self._lora_A.shape) == len(self._lora_B.shape)
|
||||||
return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
|
return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
|
||||||
|
|
||||||
def size(self, dim=None):
|
def size(self, dim=None):
|
||||||
assert dim is None
|
assert dim is None
|
||||||
return self.shape
|
return self.shape
|
||||||
|
|
||||||
def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
|
def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
|
||||||
if isinstance(shape[0], tuple):
|
if isinstance(shape[0], tuple):
|
||||||
new_shape: tuple[int] = shape[0]
|
new_shape: tuple[int, ...] = shape[0]
|
||||||
else:
|
else:
|
||||||
new_shape = cast(tuple[int], shape)
|
new_shape = cast(tuple[int, ...], shape)
|
||||||
orig_shape = self.shape
|
orig_shape = self.shape
|
||||||
|
if len(new_shape) < 2:
|
||||||
|
raise NotImplementedError # can't become a vector
|
||||||
|
|
||||||
|
# expand -1 in the shape
|
||||||
|
if any(dim == -1 for dim in new_shape):
|
||||||
|
n_elems = prod(orig_shape)
|
||||||
|
n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
|
||||||
|
assert n_elems % n_new_elems == 0
|
||||||
|
new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
|
||||||
|
|
||||||
if new_shape[-1] != orig_shape[-1]:
|
if new_shape[-1] != orig_shape[-1]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError # can't reshape the row size trivially
|
||||||
|
|
||||||
|
shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
|
||||||
|
shape_B = (*new_shape[:-1], self._rank)
|
||||||
return LoraTorchTensor(
|
return LoraTorchTensor(
|
||||||
self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
|
self._lora_A.reshape(shape_A),
|
||||||
self._lora_B.reshape((*new_shape[:-1], self._rank)),
|
self._lora_B.reshape(shape_B),
|
||||||
)
|
)
|
||||||
|
|
||||||
def reshape_as(self, other: Tensor) -> LoraTorchTensor:
|
def reshape_as(self, other: Tensor) -> LoraTorchTensor:
|
||||||
@ -134,12 +160,15 @@ class LoraTorchTensor:
|
|||||||
def permute(self, *dims: int) -> LoraTorchTensor:
|
def permute(self, *dims: int) -> LoraTorchTensor:
|
||||||
shape = self.shape
|
shape = self.shape
|
||||||
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
|
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
|
||||||
if dims[-1] == -2 and dims[-2] == -1:
|
if dims[-1] == -1:
|
||||||
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
|
# TODO: support higher dimensional A shapes bigger than 1
|
||||||
else:
|
|
||||||
assert dims[-1] == -1
|
|
||||||
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
|
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
|
||||||
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
|
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
|
||||||
|
if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
|
||||||
|
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
|
||||||
|
else:
|
||||||
|
# TODO: compose the above two
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
|
def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
|
||||||
shape = self.shape
|
shape = self.shape
|
||||||
@ -181,11 +210,13 @@ class LoraTorchTensor:
|
|||||||
torch.cat([a._lora_A for a in args[0]], dim),
|
torch.cat([a._lora_A for a in args[0]], dim),
|
||||||
torch.cat([b._lora_B for b in args[0]], dim),
|
torch.cat([b._lora_B for b in args[0]], dim),
|
||||||
)
|
)
|
||||||
else:
|
elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
|
||||||
return LoraTorchTensor(
|
return LoraTorchTensor(
|
||||||
args[0][0]._lora_A, # TODO: is this correct? (can't cat over the rank)
|
args[0][0]._lora_A,
|
||||||
torch.cat([b._lora_B for b in args[0]], dim),
|
torch.cat([b._lora_B for b in args[0]], dim),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -205,13 +236,17 @@ def parse_args() -> argparse.Namespace:
|
|||||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
|
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
|
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bigendian", action="store_true",
|
"--bigendian", action="store_true",
|
||||||
help="model is executed on big endian machine",
|
help="model is executed on big endian machine",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-lazy", action="store_true",
|
||||||
|
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose", action="store_true",
|
"--verbose", action="store_true",
|
||||||
help="increase output verbosity",
|
help="increase output verbosity",
|
||||||
@ -237,13 +272,16 @@ if __name__ == '__main__':
|
|||||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||||
|
"auto": gguf.LlamaFileType.GUESSED,
|
||||||
}
|
}
|
||||||
|
|
||||||
ftype = ftype_map[args.outtype]
|
ftype = ftype_map[args.outtype]
|
||||||
|
|
||||||
dir_base_model = args.base
|
dir_base_model: Path = args.base
|
||||||
dir_lora = args.lora_path
|
dir_lora: Path = args.lora_path
|
||||||
input_json = os.path.join(dir_lora, "adapter_config.json")
|
lora_config = dir_lora / "adapter_config.json"
|
||||||
input_model = os.path.join(dir_lora, "adapter_model.safetensors")
|
input_model = dir_lora / "adapter_model.safetensors"
|
||||||
|
|
||||||
if args.outfile is not None:
|
if args.outfile is not None:
|
||||||
fname_out = args.outfile
|
fname_out = args.outfile
|
||||||
else:
|
else:
|
||||||
@ -276,6 +314,8 @@ if __name__ == '__main__':
|
|||||||
tensor_map: dict[str, PartialLoraTensor] = {}
|
tensor_map: dict[str, PartialLoraTensor] = {}
|
||||||
|
|
||||||
for name, tensor in lora_model.items():
|
for name, tensor in lora_model.items():
|
||||||
|
if self.lazy:
|
||||||
|
tensor = LazyTorchTensor.from_eager(tensor)
|
||||||
base_name = get_base_tensor_name(name)
|
base_name = get_base_tensor_name(name)
|
||||||
is_lora_a = ".lora_A.weight" in name
|
is_lora_a = ".lora_A.weight" in name
|
||||||
is_lora_b = ".lora_B.weight" in name
|
is_lora_b = ".lora_B.weight" in name
|
||||||
@ -305,16 +345,30 @@ if __name__ == '__main__':
|
|||||||
dest = super().modify_tensors(data_torch, name, bid)
|
dest = super().modify_tensors(data_torch, name, bid)
|
||||||
for dest_name, dest_data in dest:
|
for dest_name, dest_data in dest:
|
||||||
assert isinstance(dest_data, LoraTorchTensor)
|
assert isinstance(dest_data, LoraTorchTensor)
|
||||||
# logger.info(f"{orig_name} --> {dest_name}")
|
lora_a, lora_b = dest_data.get_lora_A_B()
|
||||||
yield (dest_name + ".lora_a", dest_data._lora_A)
|
|
||||||
yield (dest_name + ".lora_b", dest_data._lora_B)
|
|
||||||
|
|
||||||
model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
|
yield (dest_name + ".lora_a", lora_a)
|
||||||
|
yield (dest_name + ".lora_b", lora_b)
|
||||||
|
|
||||||
|
model_instance = LoraModel(
|
||||||
|
dir_base_model,
|
||||||
|
ftype,
|
||||||
|
fname_out,
|
||||||
|
is_big_endian=args.bigendian,
|
||||||
|
use_temp_file=False,
|
||||||
|
eager=args.no_lazy,
|
||||||
|
model_name=None,
|
||||||
|
)
|
||||||
logger.info("Set model parameters")
|
logger.info("Set model parameters")
|
||||||
model_instance.set_gguf_parameters()
|
model_instance.set_gguf_parameters()
|
||||||
|
|
||||||
# adapter_config = json.load(input_json)
|
with open(lora_config, "r") as f:
|
||||||
|
lparams: dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
alpha = lparams["lora_alpha"]
|
||||||
|
|
||||||
model_instance.gguf_writer.add_string("training.type", "finetune_lora")
|
model_instance.gguf_writer.add_string("training.type", "finetune_lora")
|
||||||
|
model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha))
|
||||||
|
|
||||||
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||||
logger.info("Exporting model...")
|
logger.info("Exporting model...")
|
||||||
|
@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
|
|||||||
osize *= dim
|
osize *= dim
|
||||||
out = np.empty(shape=osize, dtype=otype)
|
out = np.empty(shape=osize, dtype=otype)
|
||||||
# compute over groups of 16 rows (arbitrary, but seems good for performance)
|
# compute over groups of 16 rows (arbitrary, but seems good for performance)
|
||||||
n_groups = rows.shape[0] // 16
|
n_groups = (rows.shape[0] // 16) or 1
|
||||||
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
|
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
|
||||||
return out.reshape(oshape)
|
return out.reshape(oshape)
|
||||||
|
|
||||||
|
@ -379,6 +379,7 @@ enum llm_kv {
|
|||||||
LLM_KV_TOKENIZER_EOT_ID,
|
LLM_KV_TOKENIZER_EOT_ID,
|
||||||
|
|
||||||
LLM_KV_TRAINING_TYPE,
|
LLM_KV_TRAINING_TYPE,
|
||||||
|
LLM_KV_TRAINING_LORA_ALPHA,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||||
@ -473,7 +474,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||||||
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
|
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
|
||||||
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
|
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
|
||||||
|
|
||||||
{ LLM_KV_TRAINING_TYPE, "training.type" },
|
{ LLM_KV_TRAINING_TYPE, "training.type" },
|
||||||
|
{ LLM_KV_TRAINING_LORA_ALPHA, "training.lora.alpha" },
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LLM_KV {
|
struct LLM_KV {
|
||||||
@ -2848,6 +2850,8 @@ struct llama_lora_adapter {
|
|||||||
std::vector<struct ggml_context *> ctxs;
|
std::vector<struct ggml_context *> ctxs;
|
||||||
std::vector<ggml_backend_buffer_t> bufs;
|
std::vector<ggml_backend_buffer_t> bufs;
|
||||||
|
|
||||||
|
float alpha;
|
||||||
|
|
||||||
llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
|
llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
|
||||||
base_model->lora_adapters.insert(this);
|
base_model->lora_adapters.insert(this);
|
||||||
}
|
}
|
||||||
@ -7878,10 +7882,12 @@ static struct ggml_tensor * llm_build_lora_mm(
|
|||||||
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
||||||
for (auto & it : lctx.lora_adapters) {
|
for (auto & it : lctx.lora_adapters) {
|
||||||
struct llama_lora_weight * lora = it.first->get_weight(w);
|
struct llama_lora_weight * lora = it.first->get_weight(w);
|
||||||
float scale = it.second;
|
|
||||||
if (lora == nullptr) {
|
if (lora == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
const float alpha = it.first->alpha;
|
||||||
|
const float rank = (float) lora->b->ne[0];
|
||||||
|
const float scale = alpha ? it.second * alpha / rank : it.second;
|
||||||
struct ggml_tensor * ab_cur = ggml_mul_mat(
|
struct ggml_tensor * ab_cur = ggml_mul_mat(
|
||||||
ctx0, lora->b,
|
ctx0, lora->b,
|
||||||
ggml_mul_mat(ctx0, lora->a, cur)
|
ggml_mul_mat(ctx0, lora->a, cur)
|
||||||
@ -7902,10 +7908,12 @@ static struct ggml_tensor * llm_build_lora_mm_id(
|
|||||||
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
||||||
for (auto & it : lctx.lora_adapters) {
|
for (auto & it : lctx.lora_adapters) {
|
||||||
struct llama_lora_weight * lora = it.first->get_weight(w);
|
struct llama_lora_weight * lora = it.first->get_weight(w);
|
||||||
float scale = it.second;
|
|
||||||
if (lora == nullptr) {
|
if (lora == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
const float alpha = it.first->alpha;
|
||||||
|
const float rank = (float) lora->b->ne[0];
|
||||||
|
const float scale = alpha ? it.second * alpha / rank : it.second;
|
||||||
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
|
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
|
||||||
ctx0, lora->b,
|
ctx0, lora->b,
|
||||||
ggml_mul_mat_id(ctx0, lora->a, cur, ids),
|
ggml_mul_mat_id(ctx0, lora->a, cur, ids),
|
||||||
@ -18587,10 +18595,14 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
|
|||||||
|
|
||||||
// check metadata
|
// check metadata
|
||||||
{
|
{
|
||||||
auto get_kv_str = [&](std::string key) -> std::string {
|
auto get_kv_str = [&](const std::string & key) -> std::string {
|
||||||
int id = gguf_find_key(ctx_gguf, key.c_str());
|
int id = gguf_find_key(ctx_gguf, key.c_str());
|
||||||
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
|
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
|
||||||
};
|
};
|
||||||
|
auto get_kv_f32 = [&](const std::string & key) -> float {
|
||||||
|
int id = gguf_find_key(ctx_gguf, key.c_str());
|
||||||
|
return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id);
|
||||||
|
};
|
||||||
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
||||||
auto lora_arch_name = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
|
auto lora_arch_name = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
|
||||||
auto lora_arch = llm_arch_from_string(lora_arch_name);
|
auto lora_arch = llm_arch_from_string(lora_arch_name);
|
||||||
@ -18604,6 +18616,8 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
|
|||||||
gguf_free(ctx_gguf);
|
gguf_free(ctx_gguf);
|
||||||
throw std::runtime_error("expect training.type to be finetune_lora, but got: " + train_type);
|
throw std::runtime_error("expect training.type to be finetune_lora, but got: " + train_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
adapter.alpha = get_kv_f32(llm_kv(LLM_KV_TRAINING_LORA_ALPHA));
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_tensors = gguf_get_n_tensors(ctx_gguf);
|
int n_tensors = gguf_get_n_tensors(ctx_gguf);
|
||||||
|
Loading…
Reference in New Issue
Block a user