mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
Refactor lora adapter support (#8332)
* lora: load to devide buft
* add patch tensor function
* correct tensor patch
* llama_lora_adapter_apply
* correct ggml_backend_tensor_copy
* add llm_build_mm
* fix auto merge
* update based on review comments
* add convert script
* no more transpose A
* add f16 convert
* add metadata check
* add sanity check
* fix ftype
* add requirements
* fix requirements
* fix outfile
* conversion: only allow selected models
* fix types
* cuda : do not use dmmv if the tensor does not have enough cols
* llama : lora fixes
* do not disable mmap with lora
Co-authored-by: slaren <slarengh@gmail.com>
* llm_build_lora_mm_id
* convert_lora : MoE LoRA conversion support
* convert_lora : prefer safetensors, similarly to convert_hf
* convert_hf : simplify modify_tensors for InternLM2
* convert_lora : lazy conversion
* llama : load and use alpha from LoRA adapters
* llama : use llm_build_lora_mm in most model graphs
* auto scale
* Revert "auto scale"
This reverts commit 42415a4874
.
* remove redundant params
* Apply suggestions from code review
Co-authored-by: slaren <slarengh@gmail.com>
* change kv metadata
* move add_type to __init__
* convert_hf : move add_type to main()
* convert_lora : use the GGUFWriter from Model instead of overwriting it
---------
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Francis Couture-Harpin <git@compilade.net>
This commit is contained in:
parent
4db8f60fe7
commit
97bdd26eee
@ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
if (arg == "--lora") {
|
if (arg == "--lora") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.lora_adapter.emplace_back(argv[i], 1.0f);
|
params.lora_adapter.emplace_back(argv[i], 1.0f);
|
||||||
params.use_mmap = false;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "--lora-scaled") {
|
if (arg == "--lora-scaled") {
|
||||||
@ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
const char* lora_adapter = argv[i];
|
const char* lora_adapter = argv[i];
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
|
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
|
||||||
params.use_mmap = false;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "--lora-base") {
|
if (arg == "--lora-base") {
|
||||||
@ -2089,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
||||||
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
||||||
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
||||||
int err = llama_model_apply_lora_from_file(model,
|
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
|
||||||
lora_adapter.c_str(),
|
if (adapter == nullptr) {
|
||||||
lora_scale,
|
|
||||||
((i > 0) || params.lora_base.empty())
|
|
||||||
? NULL
|
|
||||||
: params.lora_base.c_str(),
|
|
||||||
params.n_threads);
|
|
||||||
if (err != 0) {
|
|
||||||
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||||
llama_free(lctx);
|
llama_free(lctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return std::make_tuple(nullptr, nullptr);
|
return std::make_tuple(nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
llama_lora_adapter_set(lctx, adapter, lora_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.ignore_eos) {
|
if (params.ignore_eos) {
|
||||||
|
@ -2264,13 +2264,6 @@ class InternLM2Model(Model):
|
|||||||
|
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
|
|
||||||
if n_head_kv is not None and n_head != n_head_kv:
|
|
||||||
n_head = n_head_kv
|
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
|
||||||
.swapaxes(1, 2)
|
|
||||||
.reshape(weights.shape))
|
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_name("InternLM2")
|
self.gguf_writer.add_name("InternLM2")
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
@ -2290,26 +2283,22 @@ class InternLM2Model(Model):
|
|||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
num_heads = self.hparams["num_attention_heads"]
|
num_heads = self.hparams["num_attention_heads"]
|
||||||
num_kv_heads = self.hparams["num_key_value_heads"]
|
num_kv_heads = self.hparams["num_key_value_heads"]
|
||||||
hidden_size = self.hparams["hidden_size"]
|
n_embd = self.hparams["hidden_size"]
|
||||||
q_per_kv = num_heads // num_kv_heads
|
q_per_kv = num_heads // num_kv_heads
|
||||||
head_dim = hidden_size // num_heads
|
head_dim = n_embd // num_heads
|
||||||
num_groups = num_heads // q_per_kv
|
num_groups = num_heads // q_per_kv
|
||||||
|
|
||||||
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
|
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
|
||||||
|
|
||||||
if re.match(qkv_pattern, name):
|
|
||||||
bid = re.findall(qkv_pattern, name)[0]
|
|
||||||
qkv = data_torch
|
qkv = data_torch
|
||||||
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
|
|
||||||
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
|
qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
|
||||||
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
|
q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
|
||||||
|
|
||||||
# The model weights of q and k equire additional reshape.
|
# The model weights of q and k equire additional reshape.
|
||||||
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
|
q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
|
||||||
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
|
k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
|
||||||
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
|
v = v.reshape((-1, v.shape[-1]))
|
||||||
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
|
|
||||||
# v = rearrange(v, " o g n i -> o (g n i)").T
|
|
||||||
v = v.reshape((v.shape[0], -1)).T
|
|
||||||
return [
|
return [
|
||||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
|
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
|
||||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
|
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
|
||||||
@ -3585,6 +3574,7 @@ def main() -> None:
|
|||||||
small_first_shard=args.no_tensor_first_split)
|
small_first_shard=args.no_tensor_first_split)
|
||||||
|
|
||||||
logger.info("Set model parameters")
|
logger.info("Set model parameters")
|
||||||
|
model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
|
||||||
model_instance.set_gguf_parameters()
|
model_instance.set_gguf_parameters()
|
||||||
|
|
||||||
logger.info("Set model tokenizer")
|
logger.info("Set model tokenizer")
|
||||||
|
374
convert_lora_to_gguf.py
Executable file
374
convert_lora_to_gguf.py
Executable file
@ -0,0 +1,374 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from math import prod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
|
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||||
|
import gguf
|
||||||
|
|
||||||
|
# reuse model definitions from convert_hf_to_gguf.py
|
||||||
|
from convert_hf_to_gguf import LazyTorchTensor, Model
|
||||||
|
|
||||||
|
logger = logging.getLogger("lora-to-gguf")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PartialLoraTensor:
|
||||||
|
A: Tensor | None = None
|
||||||
|
B: Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# magic to support tensor shape modifications and splitting
|
||||||
|
class LoraTorchTensor:
|
||||||
|
_lora_A: Tensor # (n_rank, row_size)
|
||||||
|
_lora_B: Tensor # (col_size, n_rank)
|
||||||
|
_rank: int
|
||||||
|
|
||||||
|
def __init__(self, A: Tensor, B: Tensor):
|
||||||
|
assert len(A.shape) == len(B.shape)
|
||||||
|
assert A.shape[-2] == B.shape[-1]
|
||||||
|
if A.dtype != B.dtype:
|
||||||
|
A = A.to(torch.float32)
|
||||||
|
B = B.to(torch.float32)
|
||||||
|
self._lora_A = A
|
||||||
|
self._lora_B = B
|
||||||
|
self._rank = B.shape[-1]
|
||||||
|
|
||||||
|
def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
|
||||||
|
return (self._lora_A, self._lora_B)
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self,
|
||||||
|
indices: (
|
||||||
|
SupportsIndex
|
||||||
|
| slice
|
||||||
|
| tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
|
||||||
|
),
|
||||||
|
) -> LoraTorchTensor:
|
||||||
|
shape = self.shape
|
||||||
|
if isinstance(indices, SupportsIndex):
|
||||||
|
if len(shape) > 2:
|
||||||
|
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError # can't return a vector
|
||||||
|
elif isinstance(indices, slice):
|
||||||
|
if len(shape) > 2:
|
||||||
|
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
|
||||||
|
else:
|
||||||
|
return LoraTorchTensor(self._lora_A, self._lora_B[indices])
|
||||||
|
elif isinstance(indices, tuple):
|
||||||
|
assert len(indices) > 0
|
||||||
|
if indices[-1] is Ellipsis:
|
||||||
|
return self[indices[:-1]]
|
||||||
|
# expand ellipsis
|
||||||
|
indices = tuple(
|
||||||
|
u
|
||||||
|
for v in (
|
||||||
|
(
|
||||||
|
(slice(None, None) for _ in range(len(indices) - 1))
|
||||||
|
if i is Ellipsis
|
||||||
|
else (i,)
|
||||||
|
)
|
||||||
|
for i in indices
|
||||||
|
)
|
||||||
|
for u in v
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(indices) < len(shape):
|
||||||
|
indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
|
||||||
|
|
||||||
|
# TODO: make sure this is correct
|
||||||
|
indices_A = (
|
||||||
|
*(
|
||||||
|
(
|
||||||
|
j.__index__() % self._lora_A.shape[i]
|
||||||
|
if isinstance(j, SupportsIndex)
|
||||||
|
else slice(None, None)
|
||||||
|
)
|
||||||
|
for i, j in enumerate(indices[:-2])
|
||||||
|
),
|
||||||
|
slice(None, None),
|
||||||
|
indices[-1],
|
||||||
|
)
|
||||||
|
indices_B = indices[:-1]
|
||||||
|
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError # unknown indice type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
assert self._lora_A.dtype == self._lora_B.dtype
|
||||||
|
return self._lora_A.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> tuple[int, ...]:
|
||||||
|
assert len(self._lora_A.shape) == len(self._lora_B.shape)
|
||||||
|
return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
|
||||||
|
|
||||||
|
def size(self, dim=None):
|
||||||
|
assert dim is None
|
||||||
|
return self.shape
|
||||||
|
|
||||||
|
def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
|
||||||
|
if isinstance(shape[0], tuple):
|
||||||
|
new_shape: tuple[int, ...] = shape[0]
|
||||||
|
else:
|
||||||
|
new_shape = cast(tuple[int, ...], shape)
|
||||||
|
orig_shape = self.shape
|
||||||
|
if len(new_shape) < 2:
|
||||||
|
raise NotImplementedError # can't become a vector
|
||||||
|
|
||||||
|
# expand -1 in the shape
|
||||||
|
if any(dim == -1 for dim in new_shape):
|
||||||
|
n_elems = prod(orig_shape)
|
||||||
|
n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
|
||||||
|
assert n_elems % n_new_elems == 0
|
||||||
|
new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
|
||||||
|
|
||||||
|
if new_shape[-1] != orig_shape[-1]:
|
||||||
|
raise NotImplementedError # can't reshape the row size trivially
|
||||||
|
|
||||||
|
shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
|
||||||
|
shape_B = (*new_shape[:-1], self._rank)
|
||||||
|
return LoraTorchTensor(
|
||||||
|
self._lora_A.reshape(shape_A),
|
||||||
|
self._lora_B.reshape(shape_B),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reshape_as(self, other: Tensor) -> LoraTorchTensor:
|
||||||
|
return self.reshape(*other.shape)
|
||||||
|
|
||||||
|
def view(self, *size: int) -> LoraTorchTensor:
|
||||||
|
return self.reshape(*size)
|
||||||
|
|
||||||
|
def permute(self, *dims: int) -> LoraTorchTensor:
|
||||||
|
shape = self.shape
|
||||||
|
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
|
||||||
|
if dims[-1] == -1:
|
||||||
|
# TODO: support higher dimensional A shapes bigger than 1
|
||||||
|
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
|
||||||
|
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
|
||||||
|
if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
|
||||||
|
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
|
||||||
|
else:
|
||||||
|
# TODO: compose the above two
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
|
||||||
|
shape = self.shape
|
||||||
|
dims = [i for i in range(len(shape))]
|
||||||
|
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
|
||||||
|
return self.permute(*dims)
|
||||||
|
|
||||||
|
def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
|
||||||
|
return self.transpose(axis0, axis1)
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
|
||||||
|
del types # unused
|
||||||
|
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
if func is torch.permute:
|
||||||
|
return type(args[0]).permute(*args, **kwargs)
|
||||||
|
elif func is torch.reshape:
|
||||||
|
return type(args[0]).reshape(*args, **kwargs)
|
||||||
|
elif func is torch.stack:
|
||||||
|
assert isinstance(args[0], Sequence)
|
||||||
|
dim = kwargs.get("dim", 0)
|
||||||
|
assert dim == 0
|
||||||
|
return LoraTorchTensor(
|
||||||
|
torch.stack([a._lora_A for a in args[0]], dim),
|
||||||
|
torch.stack([b._lora_B for b in args[0]], dim),
|
||||||
|
)
|
||||||
|
elif func is torch.cat:
|
||||||
|
assert isinstance(args[0], Sequence)
|
||||||
|
dim = kwargs.get("dim", 0)
|
||||||
|
assert dim == 0
|
||||||
|
if len(args[0][0].shape) > 2:
|
||||||
|
return LoraTorchTensor(
|
||||||
|
torch.cat([a._lora_A for a in args[0]], dim),
|
||||||
|
torch.cat([b._lora_B for b in args[0]], dim),
|
||||||
|
)
|
||||||
|
elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
|
||||||
|
return LoraTorchTensor(
|
||||||
|
args[0][0]._lora_A,
|
||||||
|
torch.cat([b._lora_B for b in args[0]], dim),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_tensor_name(lora_tensor_name: str) -> str:
|
||||||
|
base_name = lora_tensor_name.replace("base_model.model.", "")
|
||||||
|
base_name = base_name.replace(".lora_A.weight", ".weight")
|
||||||
|
base_name = base_name.replace(".lora_B.weight", ".weight")
|
||||||
|
return base_name
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--outfile", type=Path,
|
||||||
|
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||||
|
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bigendian", action="store_true",
|
||||||
|
help="model is executed on big endian machine",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-lazy", action="store_true",
|
||||||
|
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose", action="store_true",
|
||||||
|
help="increase output verbosity",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base", type=Path, required=True,
|
||||||
|
help="directory containing base model file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"lora_path", type=Path,
|
||||||
|
help="directory containing LoRA adapter file",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||||
|
|
||||||
|
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||||
|
"f32": gguf.LlamaFileType.ALL_F32,
|
||||||
|
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||||
|
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||||
|
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||||
|
"auto": gguf.LlamaFileType.GUESSED,
|
||||||
|
}
|
||||||
|
|
||||||
|
ftype = ftype_map[args.outtype]
|
||||||
|
|
||||||
|
dir_base_model: Path = args.base
|
||||||
|
dir_lora: Path = args.lora_path
|
||||||
|
lora_config = dir_lora / "adapter_config.json"
|
||||||
|
input_model = dir_lora / "adapter_model.safetensors"
|
||||||
|
|
||||||
|
if args.outfile is not None:
|
||||||
|
fname_out = args.outfile
|
||||||
|
else:
|
||||||
|
# output in the same directory as the model by default
|
||||||
|
fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
|
||||||
|
|
||||||
|
if os.path.exists(input_model):
|
||||||
|
# lazy import load_file only if lora is in safetensors format.
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
lora_model = load_file(input_model, device="cpu")
|
||||||
|
else:
|
||||||
|
input_model = os.path.join(dir_lora, "adapter_model.bin")
|
||||||
|
lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
|
||||||
|
|
||||||
|
# load base model
|
||||||
|
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||||
|
hparams = Model.load_hparams(dir_base_model)
|
||||||
|
with torch.inference_mode():
|
||||||
|
try:
|
||||||
|
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||||
|
except NotImplementedError:
|
||||||
|
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
class LoraModel(model_class):
|
||||||
|
model_arch = model_class.model_arch
|
||||||
|
|
||||||
|
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||||
|
tensor_map: dict[str, PartialLoraTensor] = {}
|
||||||
|
|
||||||
|
for name, tensor in lora_model.items():
|
||||||
|
if self.lazy:
|
||||||
|
tensor = LazyTorchTensor.from_eager(tensor)
|
||||||
|
base_name = get_base_tensor_name(name)
|
||||||
|
is_lora_a = ".lora_A.weight" in name
|
||||||
|
is_lora_b = ".lora_B.weight" in name
|
||||||
|
if not is_lora_a and not is_lora_b:
|
||||||
|
if ".base_layer.weight" in name:
|
||||||
|
continue
|
||||||
|
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if base_name in tensor_map:
|
||||||
|
if is_lora_a:
|
||||||
|
tensor_map[base_name].A = tensor
|
||||||
|
else:
|
||||||
|
tensor_map[base_name].B = tensor
|
||||||
|
else:
|
||||||
|
if is_lora_a:
|
||||||
|
tensor_map[base_name] = PartialLoraTensor(A=tensor)
|
||||||
|
else:
|
||||||
|
tensor_map[base_name] = PartialLoraTensor(B=tensor)
|
||||||
|
|
||||||
|
for name, tensor in tensor_map.items():
|
||||||
|
assert tensor.A is not None
|
||||||
|
assert tensor.B is not None
|
||||||
|
yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
dest = super().modify_tensors(data_torch, name, bid)
|
||||||
|
for dest_name, dest_data in dest:
|
||||||
|
assert isinstance(dest_data, LoraTorchTensor)
|
||||||
|
lora_a, lora_b = dest_data.get_lora_A_B()
|
||||||
|
|
||||||
|
yield (dest_name + ".lora_a", lora_a)
|
||||||
|
yield (dest_name + ".lora_b", lora_b)
|
||||||
|
|
||||||
|
model_instance = LoraModel(
|
||||||
|
dir_base_model,
|
||||||
|
ftype,
|
||||||
|
fname_out,
|
||||||
|
is_big_endian=args.bigendian,
|
||||||
|
use_temp_file=False,
|
||||||
|
eager=args.no_lazy,
|
||||||
|
model_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(lora_config, "r") as f:
|
||||||
|
lparams: dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
|
alpha = lparams["lora_alpha"]
|
||||||
|
|
||||||
|
model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
|
||||||
|
model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
|
||||||
|
model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
|
||||||
|
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||||
|
logger.info("Exporting model...")
|
||||||
|
model_instance.write()
|
||||||
|
logger.info(f"Model successfully exported to {model_instance.fname_out}")
|
@ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||||||
|
|
||||||
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
||||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
|
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
|
||||||
|
&& src1->ne[1] == 1;
|
||||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||||
|
@ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
|
|||||||
|
|
||||||
fprintf(fp, "digraph G {\n");
|
fprintf(fp, "digraph G {\n");
|
||||||
fprintf(fp, " newrank = true;\n");
|
fprintf(fp, " newrank = true;\n");
|
||||||
fprintf(fp, " rankdir = LR;\n");
|
fprintf(fp, " rankdir = TB;\n");
|
||||||
|
|
||||||
for (int i = 0; i < gb->n_nodes; i++) {
|
for (int i = 0; i < gb->n_nodes; i++) {
|
||||||
struct ggml_tensor * node = gb->nodes[i];
|
struct ggml_tensor * node = gb->nodes[i];
|
||||||
@ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
|
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
|
||||||
if (ggml_nelements(node) < 5) {
|
if (ggml_nelements(node) < 5 && node->data != NULL) {
|
||||||
fprintf(fp, " | (");
|
fprintf(fp, " | (");
|
||||||
for (int j = 0; j < ggml_nelements(node); j++) {
|
for (int j = 0; j < ggml_nelements(node); j++) {
|
||||||
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
|
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
|
||||||
|
@ -19,6 +19,7 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
|
|||||||
|
|
||||||
class Keys:
|
class Keys:
|
||||||
class General:
|
class General:
|
||||||
|
TYPE = "general.type"
|
||||||
ARCHITECTURE = "general.architecture"
|
ARCHITECTURE = "general.architecture"
|
||||||
QUANTIZATION_VERSION = "general.quantization_version"
|
QUANTIZATION_VERSION = "general.quantization_version"
|
||||||
ALIGNMENT = "general.alignment"
|
ALIGNMENT = "general.alignment"
|
||||||
@ -120,11 +121,20 @@ class Keys:
|
|||||||
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
||||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||||
|
|
||||||
|
class Adapter:
|
||||||
|
TYPE = "adapter.type"
|
||||||
|
LORA_ALPHA = "adapter.lora.alpha"
|
||||||
|
|
||||||
#
|
#
|
||||||
# recommended mapping of model tensor names for storage in gguf
|
# recommended mapping of model tensor names for storage in gguf
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
class GGUFType:
|
||||||
|
MODEL = "model"
|
||||||
|
ADAPTER = "adapter"
|
||||||
|
|
||||||
|
|
||||||
class MODEL_ARCH(IntEnum):
|
class MODEL_ARCH(IntEnum):
|
||||||
LLAMA = auto()
|
LLAMA = auto()
|
||||||
FALCON = auto()
|
FALCON = auto()
|
||||||
|
@ -424,6 +424,9 @@ class GGUFWriter:
|
|||||||
fout.close()
|
fout.close()
|
||||||
self.fout = None
|
self.fout = None
|
||||||
|
|
||||||
|
def add_type(self, type_name: str) -> None:
|
||||||
|
self.add_string(Keys.General.TYPE, type_name)
|
||||||
|
|
||||||
def add_architecture(self) -> None:
|
def add_architecture(self) -> None:
|
||||||
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
|
|||||||
osize *= dim
|
osize *= dim
|
||||||
out = np.empty(shape=osize, dtype=otype)
|
out = np.empty(shape=osize, dtype=otype)
|
||||||
# compute over groups of 16 rows (arbitrary, but seems good for performance)
|
# compute over groups of 16 rows (arbitrary, but seems good for performance)
|
||||||
n_groups = rows.shape[0] // 16
|
n_groups = (rows.shape[0] // 16) or 1
|
||||||
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
|
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
|
||||||
return out.reshape(oshape)
|
return out.reshape(oshape)
|
||||||
|
|
||||||
|
@ -411,6 +411,9 @@ extern "C" {
|
|||||||
const char * content;
|
const char * content;
|
||||||
} llama_chat_message;
|
} llama_chat_message;
|
||||||
|
|
||||||
|
// lora adapter
|
||||||
|
struct llama_lora_adapter;
|
||||||
|
|
||||||
// Helpers for getting default parameters
|
// Helpers for getting default parameters
|
||||||
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||||
@ -510,18 +513,28 @@ extern "C" {
|
|||||||
const char * fname_out,
|
const char * fname_out,
|
||||||
const llama_model_quantize_params * params);
|
const llama_model_quantize_params * params);
|
||||||
|
|
||||||
// Apply a LoRA adapter to a loaded model
|
// Load a LoRA adapter from file
|
||||||
// path_base_model is the path to a higher quality model to use as a base for
|
// The loaded adapter will be associated to the given model, and will be free when the model is deleted
|
||||||
// the layers modified by the adapter. Can be NULL to use the current loaded model.
|
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
|
||||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
struct llama_model * model,
|
||||||
// will be applied on top of the previous one
|
const char * path_lora);
|
||||||
// Returns 0 on success
|
|
||||||
LLAMA_API int32_t llama_model_apply_lora_from_file(
|
// Add a loaded LoRA adapter to given context
|
||||||
const struct llama_model * model,
|
// This will not modify model's weight
|
||||||
const char * path_lora,
|
LLAMA_API int32_t llama_lora_adapter_set(
|
||||||
float scale,
|
struct llama_context * ctx,
|
||||||
const char * path_base_model,
|
struct llama_lora_adapter * adapter,
|
||||||
int32_t n_threads);
|
float scale);
|
||||||
|
|
||||||
|
// Remove a LoRA adapter from given context
|
||||||
|
// Return -1 if the adapter is not present in the context
|
||||||
|
LLAMA_API int32_t llama_lora_adapter_remove(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_lora_adapter * adapter);
|
||||||
|
|
||||||
|
// Manually free a LoRA adapter
|
||||||
|
// Note: loaded adapters will be free when the associated model is deleted
|
||||||
|
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
|
||||||
|
|
||||||
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
||||||
// the currently loaded vector.
|
// the currently loaded vector.
|
||||||
|
@ -9,3 +9,4 @@
|
|||||||
-r ./requirements/requirements-convert_hf_to_gguf.txt
|
-r ./requirements/requirements-convert_hf_to_gguf.txt
|
||||||
-r ./requirements/requirements-convert_hf_to_gguf_update.txt
|
-r ./requirements/requirements-convert_hf_to_gguf_update.txt
|
||||||
-r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
|
-r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
|
||||||
|
-r ./requirements/requirements-convert_lora_to_gguf.txt
|
||||||
|
2
requirements/requirements-convert_lora_to_gguf.txt
Normal file
2
requirements/requirements-convert_lora_to_gguf.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
-r ./requirements-convert_hf_to_gguf.txt
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
988
src/llama.cpp
988
src/llama.cpp
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user