mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
convert-new.py : pick #2427 for HF 70B support
This commit is contained in:
parent
c8ee87f141
commit
5ec18934ad
@ -104,7 +104,7 @@ TENSORS_SET = set(TENSORS_LIST)
|
||||
|
||||
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||
# hardcoded magic range
|
||||
for n_mult in range(256, 1, -1):
|
||||
for n_mult in range(8192, 1, -1):
|
||||
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
|
||||
if calc_ff == n_ff:
|
||||
return n_mult
|
||||
@ -113,11 +113,12 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_mult: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_mult: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_kv_head: Optional[int] # This parameter is only used for Llama 2
|
||||
|
||||
@staticmethod
|
||||
def guessed(model: 'LazyModel') -> 'Params':
|
||||
@ -139,31 +140,34 @@ class Params:
|
||||
n_head=n_embd // 128 # guessed
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = 256,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = 256,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
||||
config = json.load(open(config_path))
|
||||
|
||||
n_vocab = config["vocab_size"];
|
||||
n_embd = config["hidden_size"];
|
||||
n_head = config["num_attention_heads"];
|
||||
n_layer = config["num_hidden_layers"];
|
||||
n_ff = config["intermediate_size"];
|
||||
n_vocab = config["vocab_size"];
|
||||
n_embd = config["hidden_size"];
|
||||
n_head = config["num_attention_heads"];
|
||||
n_layer = config["num_hidden_layers"];
|
||||
n_ff = config["intermediate_size"];
|
||||
n_kv_head = config.get("num_key_value_heads")
|
||||
|
||||
n_mult = find_n_mult(n_ff, n_embd);
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = n_kv_head,
|
||||
)
|
||||
|
||||
# LLaMA v2 70B params.json
|
||||
@ -182,11 +186,12 @@ class Params:
|
||||
n_vocab = model["tok_embeddings.weight"].shape[0]
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -293,10 +298,12 @@ class SentencePieceVocab:
|
||||
Vocab = Union[BpeVocab, SentencePieceVocab]
|
||||
|
||||
|
||||
def permute(weights: NDArray, n_head: int) -> NDArray:
|
||||
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
n_head //= n_kv_head
|
||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
|
||||
class Tensor(metaclass=ABCMeta):
|
||||
@ -305,7 +312,7 @@ class Tensor(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def astype(self, data_type: DataType) -> 'Tensor': ...
|
||||
@abstractmethod
|
||||
def permute(self, n_head: int) -> 'Tensor': ...
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
|
||||
@abstractmethod
|
||||
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
|
||||
@abstractmethod
|
||||
@ -343,8 +350,8 @@ class UnquantizedTensor(Tensor):
|
||||
r = self.ndarray.shape[0] // 3
|
||||
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
||||
|
||||
def permute(self, n_head: int) -> 'UnquantizedTensor':
|
||||
return UnquantizedTensor(permute(self.ndarray, n_head))
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
|
||||
return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
|
||||
|
||||
|
||||
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
||||
@ -367,18 +374,18 @@ GGMLCompatibleTensor = Union[UnquantizedTensor]
|
||||
|
||||
|
||||
class DeferredPermutedTensor(Tensor):
|
||||
def __init__(self, base: Tensor, n_head: int) -> None:
|
||||
def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
|
||||
self.base = base
|
||||
self.n_head = n_head
|
||||
self.data_type = self.base.data_type
|
||||
|
||||
def astype(self, data_type: DataType) -> Tensor:
|
||||
return self.base.astype(data_type).permute(self.n_head)
|
||||
return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
|
||||
|
||||
def to_ggml(self) -> GGMLCompatibleTensor:
|
||||
return self.base.to_ggml().permute(self.n_head)
|
||||
return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
|
||||
|
||||
def permute(self, n_head: int) -> Tensor:
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
|
||||
raise Exception("shouldn't permute twice")
|
||||
|
||||
|
||||
@ -474,10 +481,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
|
||||
return ModelPlus(model, paths, format, vocab)
|
||||
|
||||
|
||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
|
||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
return lazy_tensor.load().permute(n_head)
|
||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
|
||||
return lazy_tensor.load().permute(n_head, n_kv_head)
|
||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
|
||||
|
||||
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
@ -502,7 +509,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
|
||||
for i in itertools.count():
|
||||
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
||||
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
|
||||
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
|
||||
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
|
||||
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
||||
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
||||
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
|
||||
|
Loading…
Reference in New Issue
Block a user