MPT : support GQA for replit-code-v1.5 (#3627)

Author: cebtenzzre
Date:   2023-10-15 02:32:06 -04:00 (committed by GitHub)
Parent: 11dc1091f6
Commit: 11bff29045
2 changed files with 5 additions and 3 deletions

File: convert-mpt-hf-to-gguf.py

@@ -98,6 +98,8 @@ gguf_writer.add_embedding_length(hparams["d_model"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
+if kv_n_heads := hparams["attn_config"].get("kv_n_heads"):
+    gguf_writer.add_head_count_kv(kv_n_heads)
 gguf_writer.add_layer_norm_eps(1e-05)
 if hparams["attn_config"]["clip_qkv"] is not None:
     gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
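
For context, the two values the script reads come straight from the model's config.json. A minimal sketch of an MPT-style hparams dict (illustrative values, not taken from replit-code-v1.5's actual config):

    hparams = {
        "d_model": 3072,
        "n_heads": 24,
        "attn_config": {
            "kv_n_heads": 8,   # present only on GQA checkpoints
            "clip_qkv": None,
        },
    }

Plain MPT checkpoints omit kv_n_heads, so .get() returns None, the guarded branch is skipped, and head_count_kv is never written to the GGUF file; the loader is then expected to fall back to n_head_kv == n_head.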

File: llama.cpp

@@ -2839,7 +2839,7 @@ static void llm_load_tensors(
         auto & layer = model.layers[i];

         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
+        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
         layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
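
The wqkv shape change above is the core of the fix: the fused QKV projection now carries a full-width Q plus K and V at the reduced GQA width, i.e. n_embd + 2*n_embd_gqa columns instead of 3*n_embd. A sketch of the arithmetic, using illustrative numbers rather than replit-code-v1.5's actual config:

    n_embd      = 3072                     # d_model
    n_head      = 24                       # n_heads
    n_head_kv   = 8                        # attn_config.kv_n_heads
    n_embd_head = n_embd // n_head         # 128, per-head dimension
    n_embd_gqa  = n_embd_head * n_head_kv  # 1024, total K (or V) width

    old_width = 3 * n_embd                 # 9216 -- assumed MHA, wrong for GQA
    new_width = n_embd + 2 * n_embd_gqa    # 5120 -- full Q, reduced K and V

For a non-GQA model n_head_kv == n_head, so n_embd_gqa == n_embd and the new expression reduces to the old 3*n_embd; the change is a no-op for plain MPT.
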
@@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
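
Further down, llm_build_mpt slices the fused projection output into Q, K, and V at exactly these widths, so n_embd_gqa now carries real information and the stale "no MQA/GQA" comment had to go. A minimal NumPy sketch of the split (illustrative only, not llama.cpp's actual ggml_view calls):

    import numpy as np

    n_embd, n_embd_gqa = 3072, 1024  # same illustrative values as above
    n_tokens = 5
    qkv = np.random.randn(n_tokens, n_embd + 2 * n_embd_gqa)  # fused wqkv output

    q = qkv[:, :n_embd]                      # (n_tokens, n_embd)
    k = qkv[:, n_embd:n_embd + n_embd_gqa]   # (n_tokens, n_embd_gqa)
    v = qkv[:, n_embd + n_embd_gqa:]         # (n_tokens, n_embd_gqa)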