mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
convert-llama-h5-to-gguf.py : add 70b gqa support
This commit is contained in:
parent
ca4758290c
commit
2dd5d2c92c
@ -1,4 +1,4 @@
|
|||||||
# HF llama --> gguf conversion, GQA/70b not supported
|
# HF llama --> gguf conversion
|
||||||
|
|
||||||
import gguf
|
import gguf
|
||||||
import gguf_namemap as tmap
|
import gguf_namemap as tmap
|
||||||
@ -10,7 +10,7 @@ import json
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
@ -18,11 +18,11 @@ from sentencepiece import SentencePieceProcessor
|
|||||||
# compatible with python < 3.9
|
# compatible with python < 3.9
|
||||||
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
||||||
|
|
||||||
|
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
||||||
def permute(weights: NDArray, n_head: int) -> NDArray:
|
if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
.swapaxes(1, 2)
|
.swapaxes(1, 2)
|
||||||
.reshape(weights.shape))
|
.reshape(weights.shape))
|
||||||
|
|
||||||
def count_model_parts(dir_model: str) -> int:
|
def count_model_parts(dir_model: str) -> int:
|
||||||
num_parts = 0
|
num_parts = 0
|
||||||
@ -220,7 +220,7 @@ for part_name in part_names:
|
|||||||
|
|
||||||
# permute these
|
# permute these
|
||||||
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
||||||
data = permute(data,head_count)
|
data = permute(data, head_count, head_count_kv)
|
||||||
|
|
||||||
# map tensor names
|
# map tensor names
|
||||||
if name.endswith(".weight") and name[:-7] in tensor_map:
|
if name.endswith(".weight") and name[:-7] in tensor_map:
|
||||||
@ -289,7 +289,7 @@ for part_name in part_names:
|
|||||||
|
|
||||||
# permute these
|
# permute these
|
||||||
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
||||||
data = permute(data, head_count)
|
data = permute(data, head_count, head_count_kv)
|
||||||
|
|
||||||
# map tensor names
|
# map tensor names
|
||||||
if name.endswith(".weight") and name[:-7] in tensor_map:
|
if name.endswith(".weight") and name[:-7] in tensor_map:
|
||||||
|
Loading…
Reference in New Issue
Block a user