convert_lora : MoE LoRA conversion support

* convert_lora : prefer safetensors, similarly to convert_hf
Francis Couture-Harpin 2024-07-09 18:26:38 -04:00
parent 916e95928b
commit 9d96328bdf
2 changed files with 218 additions and 59 deletions
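For context: a PEFT LoRA adapter stores each adapted weight as a low-rank pair of tensors, lora_A and lora_B, whose product is the weight delta. The converter changed below never materializes that product; it keeps the pair separate all the way into the GGUF file. A minimal sketch of the convention this diff assumes (shapes illustrative):

    import torch

    rank, n_out, n_in = 16, 4096, 4096
    A = torch.randn(rank, n_in)   # lora_A: (rank, in_features)
    B = torch.randn(n_out, rank)  # lora_B: (out_features, rank)

    # The implied weight delta; at inference W' = W + scale * (B @ A).
    delta_W = B @ A
    assert delta_W.shape == (n_out, n_in)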

convert_hf_to_gguf.py

@@ -373,9 +373,6 @@ class Model:
         except KeyError:
             raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
 
-    def support_lora(self) -> bool:
-        return False
-
     # used for GPT-2 BPE and WordPiece vocabs
     def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         tokens: list[str] = []
@@ -1415,9 +1412,9 @@ class LlamaModel(Model):
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
 
-        if name.endswith(("q_proj.weight", "q_proj.bias", "q_proj.lora_B.weight")):
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith(("k_proj.weight", "k_proj.bias", "k_proj.lora_B.weight")):
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
 
         # process the experts separately
@@ -1465,10 +1462,6 @@ class LlamaModel(Model):
             if len(experts) > 0:
                 raise ValueError(f"Unprocessed experts: {experts}")
 
-    def support_lora(self) -> bool:
-        # TODO: support lora conversion for MOE
-        return "num_local_experts" not in self.hparams
-
 
 @Model.register("BitnetForCausalLM")
 class BitnetModel(Model):
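The q_proj/k_proj special cases for lora_B.weight are dropped above because the new LoraTorchTensor wrapper (added to convert_lora_to_gguf.py below) forwards reshape/swapaxes/permute to the underlying factor pair, so LlamaModel.permute now works on wrapped LoRA tensors unchanged. The identity this relies on, as a standalone sketch with illustrative shapes: row-permuting the combined weight B @ A is the same as row-permuting lora_B alone, which is why only the B factor ever needed the permutation.

    import torch

    rank, n_out, n_in = 4, 8, 6
    A = torch.randn(rank, n_in)
    B = torch.randn(n_out, rank)
    perm = torch.randperm(n_out)

    # Row-permuting the combined weight == row-permuting lora_B first.
    assert torch.allclose((B @ A)[perm], B[perm] @ A)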

convert_lora_to_gguf.py

@@ -3,13 +3,14 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 import logging
 import argparse
 import os
 import sys
-import types
 from pathlib import Path
-from typing import TYPE_CHECKING, Iterator
+from types import EllipsisType
+from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
 
 import torch
@@ -26,6 +27,169 @@ from convert_hf_to_gguf import Model
 logger = logging.getLogger("lora-to-gguf")
 
 
+@dataclass
+class PartialLoraTensor:
+    A: Tensor | None = None
+    B: Tensor | None = None
+
+
+# magic to support tensor shape modifications and splitting
+class LoraTorchTensor:
+    _lora_A: Tensor
+    _lora_B: Tensor
+    _rank: int
+
+    def __init__(self, A: Tensor, B: Tensor):
+        assert len(A.shape) == len(B.shape)
+        if A.dtype != B.dtype:
+            A = A.to(torch.float32)
+            B = B.to(torch.float32)
+        self._lora_A = A
+        self._lora_B = B
+        assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
+        self._rank = self._lora_B.shape[-1]
+
+    def __getitem__(
+        self,
+        indices: (
+            SupportsIndex
+            | slice
+            | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+        ),
+    ) -> LoraTorchTensor:
+        shape = self.shape
+        if isinstance(indices, (SupportsIndex, slice)):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                raise NotImplementedError
+        elif isinstance(indices, tuple):
+            assert len(indices) > 0
+            if isinstance(indices[-1], EllipsisType):
+                return self[indices[:-1]]
+            # expand ellipsis
+            indices = tuple(
+                u
+                for v in (
+                    (
+                        (slice(None, None) for _ in range(len(indices) - 1))
+                        if isinstance(i, EllipsisType)
+                        else (i,)
+                    )
+                    for i in indices
+                )
+                for u in v
+            )
+
+            if len(indices) < len(shape):
+                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+
+            # TODO: make sure this is correct
+            # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
+            indices_A = (
+                *(
+                    0 if isinstance(i, SupportsIndex) else slice(None, None)
+                    for i in indices[:-2]
+                ),
+                slice(None, None),
+                indices[-1],
+            )
+            indices_B = indices[:-1]
+            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
+        else:
+            raise NotImplementedError
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._lora_A.dtype == self._lora_B.dtype
+        return self._lora_A.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
+
+    def size(self, dim=None):
+        assert dim is None
+        return self.shape
+
+    def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+        if isinstance(shape[0], tuple):
+            new_shape: tuple[int] = shape[0]
+        else:
+            new_shape = cast(tuple[int], shape)
+        orig_shape = self.shape
+        if new_shape[-1] != orig_shape[-1]:
+            raise NotImplementedError
+        return LoraTorchTensor(
+            self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
+            self._lora_B.reshape((*new_shape[:-1], self._rank)),
+        )
+
+    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
+        return self.reshape(*other.shape)
+
+    def view(self, *size: int) -> LoraTorchTensor:
+        return self.reshape(*size)
+
+    def permute(self, *dims: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
+        if dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            assert dims[-1] == -1
+            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
+            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+
+    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = [i for i in range(len(shape))]
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        return self.permute(*dims)
+
+    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
+        return self.transpose(axis0, axis1)
+
+    def to(self, *args, **kwargs):
+        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+
+    @classmethod
+    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.permute:
+            return type(args[0]).permute(*args, **kwargs)
+        elif func is torch.reshape:
+            return type(args[0]).reshape(*args, **kwargs)
+        elif func is torch.stack:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            return LoraTorchTensor(
+                torch.stack([a._lora_A for a in args[0]], dim),
+                torch.stack([b._lora_B for b in args[0]], dim),
+            )
+        elif func is torch.cat:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            if len(args[0][0].shape) > 2:
+                return LoraTorchTensor(
+                    torch.cat([a._lora_A for a in args[0]], dim),
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            else:
+                return LoraTorchTensor(
+                    args[0][0]._lora_A,  # TODO: is this correct? (can't cat over the rank)
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+        else:
+            raise NotImplementedError
+
 def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = lora_tensor_name.replace("base_model.model.", "")
     base_name = base_name.replace(".lora_A.weight", ".weight")
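Two quick illustrations of the pieces above, assuming a typical PEFT tensor name. The second shows the point of __torch_function__: torch.stack over wrapped tensors, as the MoE expert-merging code in convert_hf_to_gguf.py does, stacks the A and B factors separately instead of materializing B @ A.

    name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    assert get_base_tensor_name(name) == "model.layers.0.self_attn.q_proj.weight"

    # Stacking two rank-4 adapters for 8x6 weights, e.g. two experts:
    t0 = LoraTorchTensor(torch.randn(4, 6), torch.randn(8, 4))
    t1 = LoraTorchTensor(torch.randn(4, 6), torch.randn(8, 4))
    stacked = torch.stack([t0, t1], dim=0)  # dispatched to __torch_function__
    assert stacked.shape == (2, 8, 6)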
@@ -79,7 +243,7 @@ if __name__ == '__main__':
     dir_base_model = args.base
     dir_lora = args.lora_path
     input_json = os.path.join(dir_lora, "adapter_config.json")
-    input_model = os.path.join(dir_lora, "adapter_model.bin")
+    input_model = os.path.join(dir_lora, "adapter_model.safetensors")
     if args.outfile is not None:
         fname_out = args.outfile
     else:
@@ -87,12 +251,13 @@ if __name__ == '__main__':
         fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
 
     if os.path.exists(input_model):
-        lora_model = torch.load(input_model, map_location="cpu")
-    else:
-        input_model = os.path.join(dir_lora, "adapter_model.safetensors")
         # lazy import load_file only if lora is in safetensors format.
         from safetensors.torch import load_file
         lora_model = load_file(input_model, device="cpu")
+    else:
+        input_model = os.path.join(dir_lora, "adapter_model.bin")
+        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
@@ -104,53 +269,54 @@ if __name__ == '__main__':
         logger.error(f"Model {hparams['architectures'][0]} is not supported")
         sys.exit(1)
 
-    model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+    class LoraModel(model_class):
+        model_arch = model_class.model_arch
+
+        def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+            tensor_map: dict[str, PartialLoraTensor] = {}
+
+            for name, tensor in lora_model.items():
+                base_name = get_base_tensor_name(name)
+                is_lora_a = ".lora_A.weight" in name
+                is_lora_b = ".lora_B.weight" in name
+                if not is_lora_a and not is_lora_b:
+                    if ".base_layer.weight" in name:
+                        continue
+                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    sys.exit(1)
+
+                if base_name in tensor_map:
+                    if is_lora_a:
+                        tensor_map[base_name].A = tensor
+                    else:
+                        tensor_map[base_name].B = tensor
+                else:
+                    if is_lora_a:
+                        tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                    else:
+                        tensor_map[base_name] = PartialLoraTensor(B=tensor)
+
+            for name, tensor in tensor_map.items():
+                assert tensor.A is not None
+                assert tensor.B is not None
+                yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
+
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+            dest = super().modify_tensors(data_torch, name, bid)
+            for dest_name, dest_data in dest:
+                assert isinstance(dest_data, LoraTorchTensor)
+                # logger.info(f"{orig_name} --> {dest_name}")
+                yield (dest_name + ".lora_a", dest_data._lora_A)
+                yield (dest_name + ".lora_b", dest_data._lora_B)
+
+    model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
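The end-to-end flow for one adapter pair, with illustrative names (the blk.* form assumes the usual gguf tensor-name mapping):

    # get_tensors() pairs:
    #   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    #   base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight
    # into one LoraTorchTensor under "model.layers.0.self_attn.q_proj.weight";
    # modify_tensors() then maps that name and splits the pair back out as
    #   blk.0.attn_q.weight.lora_a  (the A factor)
    #   blk.0.attn_q.weight.lora_b  (the B factor)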
 
     logger.info("Set model parameters")
     model_instance.set_gguf_parameters()
 
     # adapter_config = json.load(input_json)
     model_instance.gguf_writer.add_string("training.type", "finetune_lora")
 
-    if not model_instance.support_lora():
-        logger.error("LoRA conversion is not yet supported for this model")
-        sys.exit(1)
-
-    # map original name to gguf name
-    map_name: dict[str, str] = {}
-    for tensor_name, tensor in lora_model.items():
-        base_name = get_base_tensor_name(tensor_name)
-        is_lora_a = ".lora_A.weight" in tensor_name
-        is_lora_b = ".lora_B.weight" in tensor_name
-        if not is_lora_a and not is_lora_b:
-            logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
-            sys.exit(1)
-        dest_name = model_instance.map_tensor_name(base_name)
-        dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
-        map_name[tensor_name] = dest_name
-
-    # overwrite method
-    def map_tensor_name(self, name: str) -> str:
-        return map_name[name]
-
-    # overwrite method
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, tensor in lora_model.items():
-            yield (name, tensor)
-
-    # overwrite method
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-        return ftype != gguf.LlamaFileType.ALL_F32
-
-    model_instance._map_tensor_name = model_instance.map_tensor_name  # type: ignore
-    model_instance.map_tensor_name = types.MethodType(map_tensor_name, model_instance)
-
-    model_instance._get_tensors = model_instance.get_tensors  # type: ignore
-    model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
-
-    model_instance._extra_f16_tensors = model_instance.extra_f16_tensors  # type: ignore
-    model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
-
     model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
 
     logger.info("Exporting model...")
     model_instance.write()
     logger.info(f"Model successfully exported to {model_instance.fname_out}")