py : try to fix flake stuff

This commit is contained in:
Georgi Gerganov 2024-01-13 13:34:08 +02:00
parent fe252237a3
commit 1fb563ebdc
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -25,7 +25,6 @@ import gguf
# check for any of the given keys in the dictionary and return the value of the first key found
def get_key_opts(d, keys):
vals = []
for k in keys:
if k in d:
return d[k]
@ -267,7 +266,6 @@ class Model:
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
# check if tokenizer has added_tokens_decoder
if hasattr(tokenizer, "added_tokens_decoder"):
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
@ -1082,17 +1080,20 @@ class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
self.gguf_writer.add_embedding_length(get_key_opts(self.hparams, ["n_embd", "hidden_size"]))
self.gguf_writer.add_feed_forward_length(4 * get_key_opts(self.hparams, ["n_embd", "hidden_size"]))
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(4 * n_embd)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(get_key_opts(self.hparams, ["n_head", "num_attention_heads"]))
self.gguf_writer.add_head_count_kv(get_key_opts(self.hparams, ["n_head", "num_attention_heads"]))
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
self.gguf_writer.add_rope_dimension_count(
int(get_key_opts(self.hparams, ["partial_rotary_factor"]) * get_key_opts(self.hparams, ["n_embd", "hidden_size"])) // get_key_opts(self.hparams, ["n_head", "num_attention_heads"]))
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False)