convert : update phi-2 to latest HF repo (#4903)

* convert : update phi-2 to latest HF repo

ggml-ci

* py : try to fix flake stuff
This commit is contained in:
Georgi Gerganov 2024-01-13 13:44:37 +02:00 committed by GitHub
parent de473f5f8e
commit 15ebe59210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 21 deletions

View File

@ -23,6 +23,15 @@ if 'NO_LOCAL_GGUF' not in os.environ:
import gguf import gguf
# check for any of the given keys in the dictionary and return the value of the first key found
def get_key_opts(d, keys):
for k in keys:
if k in d:
return d[k]
print(f"Could not find any of {keys}")
sys.exit()
###### MODEL DEFINITIONS ###### ###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum): class SentencePieceTokenTypes(IntEnum):
@ -257,6 +266,7 @@ class Model:
toktypes.append(gguf.TokenType.USER_DEFINED) toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab: elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i]) tokens.append(reverse_vocab[i])
if hasattr(tokenizer, "added_tokens_decoder"):
if tokenizer.added_tokens_decoder[i].special: if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL) toktypes.append(gguf.TokenType.CONTROL)
else: else:
@ -1068,17 +1078,22 @@ class GPT2Model(Model):
class Phi2Model(Model): class Phi2Model(Model):
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["n_layer"] block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
self.gguf_writer.add_name("Phi2") self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(4 * n_embd)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False) self.gguf_writer.add_add_bos_token(False)

View File

@ -389,6 +389,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,

View File

@ -191,6 +191,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.w1", # qwen
"h.{bid}.mlp.c_fc", # gpt2 "h.{bid}.mlp.c_fc", # gpt2
"transformer.h.{bid}.mlp.fc1", # phi2 "transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.layers.{bid}.mlp.up_proj", # plamo
), ),
@ -232,6 +233,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"h.{bid}.mlp.c_proj", # gpt2 "h.{bid}.mlp.c_proj", # gpt2
"transformer.h.{bid}.mlp.fc2", # phi2 "transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo "model.layers.layers.{bid}.mlp.down_proj", # plamo
), ),

View File

@ -574,6 +574,9 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
@ -3676,8 +3679,19 @@ static bool llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false);
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false);
if (layer.wqkv == nullptr) {
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
}
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
@ -5637,15 +5651,25 @@ struct llm_build_context {
// self-attention // self-attention
{ {
struct ggml_tensor * Qcur = nullptr;
struct ggml_tensor * Kcur = nullptr;
struct ggml_tensor * Vcur = nullptr;
if (model.layers[il].wqkv) {
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
cb(cur, "wqkv", il); cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il); cb(cur, "bqkv", il);
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
} else {
Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
}
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);