store mqa directly

This commit is contained in:
Meng Zhang 2023-09-15 14:18:36 +08:00
parent 4420cff654
commit ab13d071e1
2 changed files with 2 additions and 20 deletions

View File

@ -209,24 +209,6 @@ for part_name in part_names:
data = data.squeeze().numpy() data = data.squeeze().numpy()
if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"):
print("Duplicate K,V heads to use MHA instead of MQA for", name)
embed_dim = hparams["n_embd"]
head_dim = embed_dim // hparams["n_head"]
# ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
q, k ,v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
if len(k.shape) == 2:
k = np.tile(k, (hparams["n_head"], 1))
v = np.tile(v, (hparams["n_head"], 1))
elif len(k.shape) == 1:
k = np.tile(k, (hparams["n_head"]))
v = np.tile(v, (hparams["n_head"]))
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
data = np.concatenate((q, k, v), axis=0)
# map tensor names # map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None: if new_name is None:

View File

@ -2259,8 +2259,8 @@ static void llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {3*n_embd}, backend_split); layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);