From 2dd5d2c92c52e19f30f1545d922689ddd216bbbb Mon Sep 17 00:00:00 2001
From: klosax <131523366+klosax@users.noreply.github.com>
Date: Tue, 15 Aug 2023 00:43:10 +0200
Subject: [PATCH] convert-llama-h5-to-gguf.py : add 70b gqa support

---
 convert-llama-h5-to-gguf.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/convert-llama-h5-to-gguf.py b/convert-llama-h5-to-gguf.py
index 1a805cff8..22405673f 100644
--- a/convert-llama-h5-to-gguf.py
+++ b/convert-llama-h5-to-gguf.py
@@ -1,4 +1,4 @@
-# HF llama --> gguf conversion, GQA/70b not supported
+# HF llama --> gguf conversion
 
 import gguf
 import gguf_namemap as tmap
@@ -10,7 +10,7 @@ import json
 import numpy as np
 import torch
 
-from typing import Any, List
+from typing import Any, List, Optional
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 
@@ -18,11 +18,11 @@ from sentencepiece import SentencePieceProcessor
 # compatible with python < 3.9
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 
-
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-            .swapaxes(1, 2)
-            .reshape(weights.shape))
+                   .swapaxes(1, 2)
+                   .reshape(weights.shape))
 
 def count_model_parts(dir_model: str) -> int:
     num_parts = 0
@@ -220,7 +220,7 @@ for part_name in part_names:
 
         # permute these
         if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
-            data = permute(data,head_count)
+            data = permute(data, head_count, head_count_kv)
 
         # map tensor names
         if name.endswith(".weight") and name[:-7] in tensor_map:
@@ -289,7 +289,7 @@ for part_name in part_names:
 
         # permute these
         if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
-            data = permute(data, head_count)
+            data = permute(data, head_count, head_count_kv)
 
         # map tensor names
         if name.endswith(".weight") and name[:-7] in tensor_map: