mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
convert : update Falcon script for new HF config (#3448)
Also adds Falcon-180B support. Closes #3049 Co-authored-by: jb <jonathan.t.barnard@gmail.com>
This commit is contained in:
parent
45eba9369f
commit
48edda30ee
@ -4,6 +4,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
@ -20,10 +21,10 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
|||||||
import gguf
|
import gguf
|
||||||
|
|
||||||
|
|
||||||
def count_model_parts(dir_model: Path) -> int:
|
def count_model_parts(dir_model: Path, prefix: str) -> int:
|
||||||
num_parts = 0
|
num_parts = 0
|
||||||
for filename in os.listdir(dir_model):
|
for filename in os.listdir(dir_model):
|
||||||
if filename.startswith("pytorch_model-"):
|
if filename.startswith(prefix):
|
||||||
num_parts += 1
|
num_parts += 1
|
||||||
|
|
||||||
if num_parts > 0:
|
if num_parts > 0:
|
||||||
@ -77,20 +78,26 @@ print("gguf: loading model "+dir_model.name)
|
|||||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||||
hparams = json.load(f)
|
hparams = json.load(f)
|
||||||
|
|
||||||
if hparams["architectures"][0] != "RWForCausalLM":
|
if hparams["architectures"][0] != "FalconForCausalLM":
|
||||||
print("Model architecture not supported: " + hparams["architectures"][0])
|
print("Model architecture not supported: " + hparams["architectures"][0])
|
||||||
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# get number of model parts
|
# get number of model parts
|
||||||
num_parts = count_model_parts(dir_model)
|
num_parts = count_model_parts(dir_model, "model-00")
|
||||||
|
if num_parts:
|
||||||
|
is_safetensors = True
|
||||||
|
from safetensors import safe_open
|
||||||
|
else:
|
||||||
|
is_safetensors = False
|
||||||
|
num_parts = count_model_parts(dir_model, "pytorch_model-")
|
||||||
|
|
||||||
ARCH=gguf.MODEL_ARCH.FALCON
|
ARCH=gguf.MODEL_ARCH.FALCON
|
||||||
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
|
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
|
||||||
|
|
||||||
print("gguf: get model metadata")
|
print("gguf: get model metadata")
|
||||||
|
|
||||||
block_count = hparams["n_layer"]
|
block_count = hparams["num_hidden_layers"]
|
||||||
|
|
||||||
gguf_writer.add_name("Falcon")
|
gguf_writer.add_name("Falcon")
|
||||||
gguf_writer.add_context_length(2048) # not in config.json
|
gguf_writer.add_context_length(2048) # not in config.json
|
||||||
@ -98,9 +105,9 @@ gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
|||||||
gguf_writer.add_embedding_length(hparams["hidden_size"])
|
gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
|
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
|
||||||
gguf_writer.add_block_count(block_count)
|
gguf_writer.add_block_count(block_count)
|
||||||
gguf_writer.add_head_count(hparams["n_head"])
|
gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
if "n_head_kv" in hparams:
|
if "num_kv_heads" in hparams:
|
||||||
gguf_writer.add_head_count_kv(hparams["n_head_kv"])
|
gguf_writer.add_head_count_kv(hparams["num_kv_heads"])
|
||||||
else:
|
else:
|
||||||
gguf_writer.add_head_count_kv(1)
|
gguf_writer.add_head_count_kv(1)
|
||||||
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
|
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
|
||||||
@ -146,8 +153,8 @@ special_vocab.add_to_gguf(gguf_writer)
|
|||||||
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
|
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
|
||||||
|
|
||||||
# params for qkv transform
|
# params for qkv transform
|
||||||
n_head = hparams["n_head"]
|
n_head = hparams["num_attention_heads"]
|
||||||
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
|
n_head_kv = hparams["num_kv_heads"] if "num_kv_heads" in hparams else 1
|
||||||
|
|
||||||
head_dim = hparams["hidden_size"] // n_head
|
head_dim = hparams["hidden_size"] // n_head
|
||||||
|
|
||||||
@ -156,6 +163,10 @@ print("gguf: get tensor metadata")
|
|||||||
|
|
||||||
if num_parts == 0:
|
if num_parts == 0:
|
||||||
part_names = iter(("pytorch_model.bin",))
|
part_names = iter(("pytorch_model.bin",))
|
||||||
|
elif is_safetensors:
|
||||||
|
part_names = (
|
||||||
|
f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
part_names = (
|
part_names = (
|
||||||
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
|
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
|
||||||
@ -165,60 +176,64 @@ for part_name in part_names:
|
|||||||
if args.vocab_only:
|
if args.vocab_only:
|
||||||
break
|
break
|
||||||
print("gguf: loading model part '" + part_name + "'")
|
print("gguf: loading model part '" + part_name + "'")
|
||||||
model_part = torch.load(dir_model / part_name, map_location="cpu")
|
if is_safetensors:
|
||||||
|
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
|
||||||
|
else:
|
||||||
|
ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
|
||||||
|
|
||||||
for name in model_part.keys():
|
with ctx as model_part:
|
||||||
data = model_part[name]
|
for name in model_part.keys():
|
||||||
|
data = model_part.get_tensor(name) if is_safetensors else model_part[name]
|
||||||
|
|
||||||
old_dtype = data.dtype
|
old_dtype = data.dtype
|
||||||
|
|
||||||
# convert any unsupported data types to float32
|
# convert any unsupported data types to float32
|
||||||
if data.dtype != torch.float16 and data.dtype != torch.float32:
|
if data.dtype != torch.float16 and data.dtype != torch.float32:
|
||||||
data = data.to(torch.float32)
|
data = data.to(torch.float32)
|
||||||
|
|
||||||
# QKV tensor transform
|
# QKV tensor transform
|
||||||
# The original query_key_value tensor contains n_head_kv "kv groups",
|
# The original query_key_value tensor contains n_head_kv "kv groups",
|
||||||
# each consisting of n_head/n_head_kv query weights followed by one key
|
# each consisting of n_head/n_head_kv query weights followed by one key
|
||||||
# and one value weight (shared by all query heads in the kv group).
|
# and one value weight (shared by all query heads in the kv group).
|
||||||
# This layout makes it a big pain to work with in GGML.
|
# This layout makes it a big pain to work with in GGML.
|
||||||
# So we rearrange them here,, so that we have n_head query weights
|
# So we rearrange them here,, so that we have n_head query weights
|
||||||
# followed by n_head_kv key weights followed by n_head_kv value weights,
|
# followed by n_head_kv key weights followed by n_head_kv value weights,
|
||||||
# in contiguous fashion.
|
# in contiguous fashion.
|
||||||
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
|
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
|
||||||
|
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
|
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
|
||||||
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
|
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
|
||||||
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
|
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
|
||||||
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
|
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
|
||||||
data = torch.cat((q,k,v)).reshape_as(data)
|
data = torch.cat((q,k,v)).reshape_as(data)
|
||||||
|
|
||||||
data = data.squeeze().numpy()
|
data = data.squeeze().numpy()
|
||||||
|
|
||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print("Can not map tensor '" + name + "'")
|
print("Can not map tensor '" + name + "'")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
|
|
||||||
# if f32 desired, convert any float16 to float32
|
# if f32 desired, convert any float16 to float32
|
||||||
if ftype == 0 and data_dtype == np.float16:
|
if ftype == 0 and data_dtype == np.float16:
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||||
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||||
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
|
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
|
||||||
|
|
||||||
gguf_writer.add_tensor(new_name, data)
|
gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
print("gguf: write header")
|
print("gguf: write header")
|
||||||
|
Loading…
Reference in New Issue
Block a user