diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ebb5ca376..70ea963f2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2222,13 +2222,6 @@ class InternLM2Model(Model): special_vocab.add_to_gguf(self.gguf_writer) - def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) - def set_gguf_parameters(self): self.gguf_writer.add_name("InternLM2") self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) @@ -2248,26 +2241,22 @@ class InternLM2Model(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] num_kv_heads = self.hparams["num_key_value_heads"] - hidden_size = self.hparams["hidden_size"] + n_embd = self.hparams["hidden_size"] q_per_kv = num_heads // num_kv_heads - head_dim = hidden_size // num_heads + head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv - qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv" - - if re.match(qkv_pattern, name): - bid = re.findall(qkv_pattern, name)[0] + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch - # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) - qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim)) - q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :] + + qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd)) + q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1] + # The model weights of q and k equire additional reshape. - # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads) - q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads) - # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads) - k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads) - # v = rearrange(v, " o g n i -> o (g n i)").T - v = v.reshape((v.shape[0], -1)).T + q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) + k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) + v = v.reshape((-1, v.shape[-1])) + return [ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 2d01fdc46..71d3e57f5 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -8,9 +8,10 @@ import logging import argparse import os import sys +import json +from math import prod from pathlib import Path -from types import EllipsisType -from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast import torch @@ -22,7 +23,7 @@ if 'NO_LOCAL_GGUF' not in os.environ: import gguf # reuse model definitions from convert_hf_to_gguf.py -from convert_hf_to_gguf import Model +from convert_hf_to_gguf import LazyTorchTensor, Model logger = logging.getLogger("lora-to-gguf") @@ -35,37 +36,45 @@ class PartialLoraTensor: # magic to support tensor shape modifications and splitting class LoraTorchTensor: - _lora_A: Tensor - _lora_B: Tensor + _lora_A: Tensor # (n_rank, row_size) + _lora_B: Tensor # (col_size, n_rank) _rank: int def __init__(self, A: Tensor, B: Tensor): assert len(A.shape) == len(B.shape) + assert A.shape[-2] == B.shape[-1] if A.dtype != B.dtype: A = A.to(torch.float32) B = B.to(torch.float32) self._lora_A = A self._lora_B = B - assert self._lora_A.shape[-2] == self._lora_B.shape[-1] - self._rank = self._lora_B.shape[-1] + self._rank = B.shape[-1] + + def get_lora_A_B(self) -> tuple[Tensor, Tensor]: + return (self._lora_A, self._lora_B) def __getitem__( self, indices: ( SupportsIndex | slice - | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...] + | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature ), ) -> LoraTorchTensor: shape = self.shape - if isinstance(indices, (SupportsIndex, slice)): + if isinstance(indices, SupportsIndex): if len(shape) > 2: return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) else: - raise NotImplementedError + raise NotImplementedError # can't return a vector + elif isinstance(indices, slice): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + return LoraTorchTensor(self._lora_A, self._lora_B[indices]) elif isinstance(indices, tuple): assert len(indices) > 0 - if isinstance(indices[-1], EllipsisType): + if indices[-1] is Ellipsis: return self[indices[:-1]] # expand ellipsis indices = tuple( @@ -73,7 +82,7 @@ class LoraTorchTensor: for v in ( ( (slice(None, None) for _ in range(len(indices) - 1)) - if isinstance(i, EllipsisType) + if i is Ellipsis else (i,) ) for i in indices @@ -85,11 +94,14 @@ class LoraTorchTensor: indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) # TODO: make sure this is correct - # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1]) indices_A = ( *( - 0 if isinstance(i, SupportsIndex) else slice(None, None) - for i in indices[:-2] + ( + j.__index__() % self._lora_A.shape[i] + if isinstance(j, SupportsIndex) + else slice(None, None) + ) + for i, j in enumerate(indices[:-2]) ), slice(None, None), indices[-1], @@ -97,7 +109,7 @@ class LoraTorchTensor: indices_B = indices[:-1] return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) else: - raise NotImplementedError + raise NotImplementedError # unknown indice type @property def dtype(self) -> torch.dtype: @@ -106,23 +118,37 @@ class LoraTorchTensor: @property def shape(self) -> tuple[int, ...]: + assert len(self._lora_A.shape) == len(self._lora_B.shape) return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) def size(self, dim=None): assert dim is None return self.shape - def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor: + def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: if isinstance(shape[0], tuple): - new_shape: tuple[int] = shape[0] + new_shape: tuple[int, ...] = shape[0] else: - new_shape = cast(tuple[int], shape) + new_shape = cast(tuple[int, ...], shape) orig_shape = self.shape + if len(new_shape) < 2: + raise NotImplementedError # can't become a vector + + # expand -1 in the shape + if any(dim == -1 for dim in new_shape): + n_elems = prod(orig_shape) + n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) + assert n_elems % n_new_elems == 0 + new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) + if new_shape[-1] != orig_shape[-1]: - raise NotImplementedError + raise NotImplementedError # can't reshape the row size trivially + + shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1]) + shape_B = (*new_shape[:-1], self._rank) return LoraTorchTensor( - self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])), - self._lora_B.reshape((*new_shape[:-1], self._rank)), + self._lora_A.reshape(shape_A), + self._lora_B.reshape(shape_B), ) def reshape_as(self, other: Tensor) -> LoraTorchTensor: @@ -134,12 +160,15 @@ class LoraTorchTensor: def permute(self, *dims: int) -> LoraTorchTensor: shape = self.shape dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) - if dims[-1] == -2 and dims[-2] == -1: - return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) - else: - assert dims[-1] == -1 + if dims[-1] == -1: + # TODO: support higher dimensional A shapes bigger than 1 assert all(dim == 1 for dim in self._lora_A.shape[:-2]) return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) + if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: + return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) + else: + # TODO: compose the above two + raise NotImplementedError def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: shape = self.shape @@ -181,11 +210,13 @@ class LoraTorchTensor: torch.cat([a._lora_A for a in args[0]], dim), torch.cat([b._lora_B for b in args[0]], dim), ) - else: + elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]): return LoraTorchTensor( - args[0][0]._lora_A, # TODO: is this correct? (can't cat over the rank) + args[0][0]._lora_A, torch.cat([b._lora_B for b in args[0]], dim), ) + else: + raise NotImplementedError else: raise NotImplementedError @@ -205,13 +236,17 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", help="model is executed on big endian machine", ) + parser.add_argument( + "--no-lazy", action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) parser.add_argument( "--verbose", action="store_true", help="increase output verbosity", @@ -237,13 +272,16 @@ if __name__ == '__main__': "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "auto": gguf.LlamaFileType.GUESSED, } + ftype = ftype_map[args.outtype] - dir_base_model = args.base - dir_lora = args.lora_path - input_json = os.path.join(dir_lora, "adapter_config.json") - input_model = os.path.join(dir_lora, "adapter_model.safetensors") + dir_base_model: Path = args.base + dir_lora: Path = args.lora_path + lora_config = dir_lora / "adapter_config.json" + input_model = dir_lora / "adapter_model.safetensors" + if args.outfile is not None: fname_out = args.outfile else: @@ -276,6 +314,8 @@ if __name__ == '__main__': tensor_map: dict[str, PartialLoraTensor] = {} for name, tensor in lora_model.items(): + if self.lazy: + tensor = LazyTorchTensor.from_eager(tensor) base_name = get_base_tensor_name(name) is_lora_a = ".lora_A.weight" in name is_lora_b = ".lora_B.weight" in name @@ -305,16 +345,30 @@ if __name__ == '__main__': dest = super().modify_tensors(data_torch, name, bid) for dest_name, dest_data in dest: assert isinstance(dest_data, LoraTorchTensor) - # logger.info(f"{orig_name} --> {dest_name}") - yield (dest_name + ".lora_a", dest_data._lora_A) - yield (dest_name + ".lora_b", dest_data._lora_B) + lora_a, lora_b = dest_data.get_lora_A_B() - model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None) + yield (dest_name + ".lora_a", lora_a) + yield (dest_name + ".lora_b", lora_b) + + model_instance = LoraModel( + dir_base_model, + ftype, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=False, + eager=args.no_lazy, + model_name=None, + ) logger.info("Set model parameters") model_instance.set_gguf_parameters() - # adapter_config = json.load(input_json) + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + + alpha = lparams["lora_alpha"] + model_instance.gguf_writer.add_string("training.type", "finetune_lora") + model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha)) model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) logger.info("Exporting model...") diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index b22eec166..16e0a9aaa 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np. osize *= dim out = np.empty(shape=osize, dtype=otype) # compute over groups of 16 rows (arbitrary, but seems good for performance) - n_groups = rows.shape[0] // 16 + n_groups = (rows.shape[0] // 16) or 1 np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) return out.reshape(oshape) diff --git a/src/llama.cpp b/src/llama.cpp index 30ecbb801..3906b9ea1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -379,6 +379,7 @@ enum llm_kv { LLM_KV_TOKENIZER_EOT_ID, LLM_KV_TRAINING_TYPE, + LLM_KV_TRAINING_LORA_ALPHA, }; static const std::map LLM_KV_NAMES = { @@ -473,7 +474,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TRAINING_TYPE, "training.type" }, + { LLM_KV_TRAINING_TYPE, "training.type" }, + { LLM_KV_TRAINING_LORA_ALPHA, "training.lora.alpha" }, }; struct LLM_KV { @@ -2848,6 +2850,8 @@ struct llama_lora_adapter { std::vector ctxs; std::vector bufs; + float alpha; + llama_lora_adapter(struct llama_model * base_model): base_model(base_model) { base_model->lora_adapters.insert(this); } @@ -7878,10 +7882,12 @@ static struct ggml_tensor * llm_build_lora_mm( struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); for (auto & it : lctx.lora_adapters) { struct llama_lora_weight * lora = it.first->get_weight(w); - float scale = it.second; if (lora == nullptr) { continue; } + const float alpha = it.first->alpha; + const float rank = (float) lora->b->ne[0]; + const float scale = alpha ? it.second * alpha / rank : it.second; struct ggml_tensor * ab_cur = ggml_mul_mat( ctx0, lora->b, ggml_mul_mat(ctx0, lora->a, cur) @@ -7902,10 +7908,12 @@ static struct ggml_tensor * llm_build_lora_mm_id( struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); for (auto & it : lctx.lora_adapters) { struct llama_lora_weight * lora = it.first->get_weight(w); - float scale = it.second; if (lora == nullptr) { continue; } + const float alpha = it.first->alpha; + const float rank = (float) lora->b->ne[0]; + const float scale = alpha ? it.second * alpha / rank : it.second; struct ggml_tensor * ab_cur = ggml_mul_mat_id( ctx0, lora->b, ggml_mul_mat_id(ctx0, lora->a, cur, ids), @@ -18587,10 +18595,14 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c // check metadata { - auto get_kv_str = [&](std::string key) -> std::string { + auto get_kv_str = [&](const std::string & key) -> std::string { int id = gguf_find_key(ctx_gguf, key.c_str()); return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); }; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf, key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id); + }; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); auto lora_arch_name = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); auto lora_arch = llm_arch_from_string(lora_arch_name); @@ -18604,6 +18616,8 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c gguf_free(ctx_gguf); throw std::runtime_error("expect training.type to be finetune_lora, but got: " + train_type); } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_TRAINING_LORA_ALPHA)); } int n_tensors = gguf_get_n_tensors(ctx_gguf);