convert.py : better always have n_head_kv and default it to n_head

This commit is contained in:
Georgi Gerganov 2023-08-17 18:47:06 +03:00
parent d646c4efce
commit 8ace03ad3d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -155,12 +155,7 @@ class Params:
n_layer = config["num_hidden_layers"] n_layer = config["num_hidden_layers"]
n_ff = config["intermediate_size"] n_ff = config["intermediate_size"]
n_head = config["num_attention_heads"] n_head = config["num_attention_heads"]
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
if "num_key_value_heads" in config:
n_head_kv = config["num_key_value_heads"]
else:
n_head_kv = None
f_norm_eps = config["rms_norm_eps"] f_norm_eps = config["rms_norm_eps"]
n_mult = Params.find_n_mult(n_ff, n_embd) n_mult = Params.find_n_mult(n_ff, n_embd)
@ -719,7 +714,7 @@ class OutputFile:
self.gguf.add_feed_forward_length (params.n_ff) self.gguf.add_feed_forward_length (params.n_ff)
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head) self.gguf.add_head_count (params.n_head)
if params.n_head_kv is not None: self.gguf.add_head_count_kv(params.n_head_kv) self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
def add_meta_vocab(self, vocab: Vocab) -> None: def add_meta_vocab(self, vocab: Vocab) -> None: