diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 61f8e370c..ebb5ca376 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -373,9 +373,6 @@ class Model:
         except KeyError:
             raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
 
-    def support_lora(self) -> bool:
-        return False
-
     # used for GPT-2 BPE and WordPiece vocabs
     def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         tokens: list[str] = []
@@ -1415,9 +1412,9 @@ class LlamaModel(Model):
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
 
-        if name.endswith(("q_proj.weight", "q_proj.bias", "q_proj.lora_B.weight")):
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith(("k_proj.weight", "k_proj.bias", "k_proj.lora_B.weight")):
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
 
         # process the experts separately
@@ -1465,10 +1462,6 @@ class LlamaModel(Model):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")
 
-    def support_lora(self) -> bool:
-        # TODO: support lora conversion for MOE
-        return "num_local_experts" not in self.hparams
-
 
 @Model.register("BitnetForCausalLM")
 class BitnetModel(Model):
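The `lora_B` special cases above can be dropped because a row permutation of a low-rank product only touches the `B` factor, which is exactly what the `LoraTorchTensor` wrapper introduced below exploits. A minimal sketch of that identity, with illustrative shapes (rank 2, an 8×6 weight), not part of the patch:

```python
import torch

B = torch.randn(8, 2)   # (out_features, rank)
A = torch.randn(2, 6)   # (rank, in_features)
perm = torch.randperm(8)

# permuting the rows of (B @ A) equals permuting the rows of B, then composing
assert torch.allclose((B @ A)[perm], B[perm] @ A)
```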
diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index c7393ac3a..2d01fdc46 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -3,13 +3,14 @@
 
 from __future__ import annotations
 
+from dataclasses import dataclass
 import logging
 import argparse
 import os
 import sys
-import types
 from pathlib import Path
-from typing import TYPE_CHECKING, Iterator
+from types import EllipsisType
+from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
 
 import torch
 
@@ -26,6 +27,169 @@ from convert_hf_to_gguf import Model
 
 logger = logging.getLogger("lora-to-gguf")
 
+@dataclass
+class PartialLoraTensor:
+    A: Tensor | None = None
+    B: Tensor | None = None
+
+
+# magic to support tensor shape modifications and splitting
+class LoraTorchTensor:
+    _lora_A: Tensor
+    _lora_B: Tensor
+    _rank: int
+
+    def __init__(self, A: Tensor, B: Tensor):
+        assert len(A.shape) == len(B.shape)
+        if A.dtype != B.dtype:
+            A = A.to(torch.float32)
+            B = B.to(torch.float32)
+        self._lora_A = A
+        self._lora_B = B
+        assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
+        self._rank = self._lora_B.shape[-1]
+
+    def __getitem__(
+        self,
+        indices: (
+            SupportsIndex
+            | slice
+            | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+        ),
+    ) -> LoraTorchTensor:
+        shape = self.shape
+        if isinstance(indices, (SupportsIndex, slice)):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                raise NotImplementedError
+        elif isinstance(indices, tuple):
+            assert len(indices) > 0
+            if isinstance(indices[-1], EllipsisType):
+                return self[indices[:-1]]
+            # expand ellipsis
+            indices = tuple(
+                u
+                for v in (
+                    (
+                        (slice(None, None) for _ in range(len(indices) - 1))
+                        if isinstance(i, EllipsisType)
+                        else (i,)
+                    )
+                    for i in indices
+                )
+                for u in v
+            )
+
+            if len(indices) < len(shape):
+                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+
+            # TODO: make sure this is correct
+            # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
+            indices_A = (
+                *(
+                    0 if isinstance(i, SupportsIndex) else slice(None, None)
+                    for i in indices[:-2]
+                ),
+                slice(None, None),
+                indices[-1],
+            )
+            indices_B = indices[:-1]
+            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
+        else:
+            raise NotImplementedError
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._lora_A.dtype == self._lora_B.dtype
+        return self._lora_A.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
+
+    def size(self, dim=None):
+        assert dim is None
+        return self.shape
+
+    def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+        if isinstance(shape[0], tuple):
+            new_shape: tuple[int] = shape[0]
+        else:
+            new_shape = cast(tuple[int], shape)
+        orig_shape = self.shape
+        if new_shape[-1] != orig_shape[-1]:
+            raise NotImplementedError
+        return LoraTorchTensor(
+            self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
+            self._lora_B.reshape((*new_shape[:-1], self._rank)),
+        )
+
+    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
+        return self.reshape(*other.shape)
+
+    def view(self, *size: int) -> LoraTorchTensor:
+        return self.reshape(*size)
+
+    def permute(self, *dims: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
+        if dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            assert dims[-1] == -1
+            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
+            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+
+    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = [i for i in range(len(shape))]
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        return self.permute(*dims)
+
+    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
+        return self.transpose(axis0, axis1)
+
+    def to(self, *args, **kwargs):
+        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+
+    @classmethod
+    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.permute:
+            return type(args[0]).permute(*args, **kwargs)
+        elif func is torch.reshape:
+            return type(args[0]).reshape(*args, **kwargs)
+        elif func is torch.stack:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            return LoraTorchTensor(
+                torch.stack([a._lora_A for a in args[0]], dim),
+                torch.stack([b._lora_B for b in args[0]], dim),
+            )
+        elif func is torch.cat:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            if len(args[0][0].shape) > 2:
+                return LoraTorchTensor(
+                    torch.cat([a._lora_A for a in args[0]], dim),
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            else:
+                return LoraTorchTensor(
+                    args[0][0]._lora_A,  # TODO: is this correct? (can't cat over the rank)
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+        else:
+            raise NotImplementedError
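The wrapper stands in for the product `B @ A` without materializing it: `shape` reports the composed shape, and `__getitem__` slices the leading (output) dimensions of `B` alone. A minimal standalone sketch of those invariants, with illustrative shapes, not part of the patch:

```python
import torch

A = torch.randn(2, 6)   # (rank, in_features)
B = torch.randn(4, 2)   # (out_features, rank)

# shape bookkeeping used by LoraTorchTensor.shape
assert (B @ A).shape == (*B.shape[:-1], A.shape[-1])  # == (4, 6)

# splitting output rows (e.g. separating a fused projection) only slices B,
# mirroring what __getitem__ does for the leading dimensions
rows = slice(0, 2)
assert torch.allclose((B @ A)[rows], B[rows] @ A)
```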
+
+
 def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = lora_tensor_name.replace("base_model.model.", "")
     base_name = base_name.replace(".lora_A.weight", ".weight")
@@ -79,7 +243,7 @@ if __name__ == '__main__':
     dir_base_model = args.base
     dir_lora = args.lora_path
     input_json = os.path.join(dir_lora, "adapter_config.json")
-    input_model = os.path.join(dir_lora, "adapter_model.bin")
+    input_model = os.path.join(dir_lora, "adapter_model.safetensors")
     if args.outfile is not None:
         fname_out = args.outfile
     else:
@@ -87,12 +251,13 @@ if __name__ == '__main__':
         fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
 
     if os.path.exists(input_model):
-        lora_model = torch.load(input_model, map_location="cpu")
-    else:
-        input_model = os.path.join(dir_lora, "adapter_model.safetensors")
         # lazy import load_file only if lora is in safetensors format.
         from safetensors.torch import load_file
+
         lora_model = load_file(input_model, device="cpu")
+    else:
+        input_model = os.path.join(dir_lora, "adapter_model.bin")
+        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
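The reordered branches now prefer `adapter_model.safetensors` and only fall back to the pickle-based `adapter_model.bin`, where `weights_only=True` keeps `torch.load` from executing arbitrary pickled code. A standalone sketch of the same logic (the helper name is illustrative, not part of the patch):

```python
import os
import torch

def load_adapter(dir_lora: str) -> dict[str, torch.Tensor]:
    # prefer the safetensors format, which cannot carry executable payloads
    path = os.path.join(dir_lora, "adapter_model.safetensors")
    if os.path.exists(path):
        from safetensors.torch import load_file
        return load_file(path, device="cpu")
    # fall back to the legacy pickle format; weights_only=True restricts
    # unpickling to plain tensors
    path = os.path.join(dir_lora, "adapter_model.bin")
    return torch.load(path, map_location="cpu", weights_only=True)
```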
@@ -104,53 +269,54 @@ if __name__ == '__main__':
         logger.error(f"Model {hparams['architectures'][0]} is not supported")
         sys.exit(1)
 
-    model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+    class LoraModel(model_class):
+        model_arch = model_class.model_arch
+
+        def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+            tensor_map: dict[str, PartialLoraTensor] = {}
+
+            for name, tensor in lora_model.items():
+                base_name = get_base_tensor_name(name)
+                is_lora_a = ".lora_A.weight" in name
+                is_lora_b = ".lora_B.weight" in name
+                if not is_lora_a and not is_lora_b:
+                    if ".base_layer.weight" in name:
+                        continue
+                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    sys.exit(1)
+
+                if base_name in tensor_map:
+                    if is_lora_a:
+                        tensor_map[base_name].A = tensor
+                    else:
+                        tensor_map[base_name].B = tensor
+                else:
+                    if is_lora_a:
+                        tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                    else:
+                        tensor_map[base_name] = PartialLoraTensor(B=tensor)
+
+            for name, tensor in tensor_map.items():
+                assert tensor.A is not None
+                assert tensor.B is not None
+                yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
+
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+            dest = super().modify_tensors(data_torch, name, bid)
+            for dest_name, dest_data in dest:
+                assert isinstance(dest_data, LoraTorchTensor)
+                # logger.info(f"{orig_name} --> {dest_name}")
+                yield (dest_name + ".lora_a", dest_data._lora_A)
+                yield (dest_name + ".lora_b", dest_data._lora_B)
+
+    model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
 
     logger.info("Set model parameters")
     model_instance.set_gguf_parameters()
 
     # adapter_config = json.load(input_json)
     model_instance.gguf_writer.add_string("training.type", "finetune_lora")
-    if not model_instance.support_lora():
-        logger.error("LoRA conversion is not yet supported for this model")
-        sys.exit(1)
 
-    # map original name to gguf name
-    map_name: dict[str, str] = {}
-    for tensor_name, tensor in lora_model.items():
-        base_name = get_base_tensor_name(tensor_name)
-        is_lora_a = ".lora_A.weight" in tensor_name
-        is_lora_b = ".lora_B.weight" in tensor_name
-        if not is_lora_a and not is_lora_b:
-            logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
-            sys.exit(1)
-        dest_name = model_instance.map_tensor_name(base_name)
-        dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
-        map_name[tensor_name] = dest_name
-
-    # overwrite method
-    def map_tensor_name(self, name: str) -> str:
-        return map_name[name]
-
-    # overwrite method
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, tensor in lora_model.items():
-            yield (name, tensor)
-
-    # overwrite method
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-        return ftype != gguf.LlamaFileType.ALL_F32
-
-    model_instance._map_tensor_name = model_instance.map_tensor_name  # type: ignore
-    model_instance.map_tensor_name = types.MethodType(map_tensor_name, model_instance)
-
-    model_instance._get_tensors = model_instance.get_tensors  # type: ignore
-    model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
-
-    model_instance._extra_f16_tensors = model_instance.extra_f16_tensors  # type: ignore
-    model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
-
-    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-    logger.info("Exporting model...")
-    model_instance.write()
-    logger.info(f"Model successfully exported to {model_instance.fname_out}")
+    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    logger.info("Exporting model...")
+    model_instance.write()
+    logger.info(f"Model successfully exported to {model_instance.fname_out}")
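For reference, the naming pipeline above ends with one `.lora_a`/`.lora_b` pair per adapted weight. A worked example, where the GGUF-side name shown is what `map_tensor_name` would typically produce for this tensor in convert_hf_to_gguf.py:

```python
# illustrative walk-through of the naming pipeline
name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

# get_base_tensor_name() strips the PEFT prefixes and the lora_A/lora_B suffix
base = name.replace("base_model.model.", "").replace(".lora_A.weight", ".weight")
assert base == "model.layers.0.self_attn.q_proj.weight"

# map_tensor_name() then maps this to the GGUF name (e.g. "blk.0.attn_q.weight"),
# and modify_tensors() finally emits "blk.0.attn_q.weight.lora_a" and ".lora_b"
```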