mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
llama : add PLaMo model (#3557)
* add plamo mock * add tensor loading * plamo convert * update norm * able to compile * fix norm_rms_eps hparam * runnable * use inp_pos * seems ok * update kqv code * remove develop code * update README * shuffle attn_q.weight and attn_output.weight for broadcasting * remove plamo_llm_build_kqv and use llm_build_kqv * fix style * update * llama : remove obsolete KQ_scale * plamo : fix tensor names for correct GPU offload --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
5bf3953d7e
commit
753be377b6
@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
|
|||||||
- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
|
- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
|
||||||
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
|
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
|
||||||
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
|
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
|
||||||
|
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
|
||||||
|
|
||||||
**Multimodal models:**
|
**Multimodal models:**
|
||||||
|
|
||||||
|
@ -184,6 +184,8 @@ class Model:
|
|||||||
return MixtralModel
|
return MixtralModel
|
||||||
if model_architecture == "PhiForCausalLM":
|
if model_architecture == "PhiForCausalLM":
|
||||||
return Phi2Model
|
return Phi2Model
|
||||||
|
if model_architecture == "PlamoForCausalLM":
|
||||||
|
return PlamoModel
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
@ -225,6 +227,8 @@ class Model:
|
|||||||
return gguf.MODEL_ARCH.LLAMA
|
return gguf.MODEL_ARCH.LLAMA
|
||||||
if arch == "PhiForCausalLM":
|
if arch == "PhiForCausalLM":
|
||||||
return gguf.MODEL_ARCH.PHI2
|
return gguf.MODEL_ARCH.PHI2
|
||||||
|
if arch == "PlamoForCausalLM":
|
||||||
|
return gguf.MODEL_ARCH.PLAMO
|
||||||
|
|
||||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||||
|
|
||||||
@ -1002,11 +1006,91 @@ class Phi2Model(Model):
|
|||||||
self.gguf_writer.add_add_bos_token(False)
|
self.gguf_writer.add_add_bos_token(False)
|
||||||
|
|
||||||
|
|
||||||
|
class PlamoModel(Model):
|
||||||
|
def set_vocab(self):
|
||||||
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
hparams = self.hparams
|
||||||
|
block_count = hparams["num_hidden_layers"]
|
||||||
|
|
||||||
|
self.gguf_writer.add_name("PLaMo")
|
||||||
|
self.gguf_writer.add_context_length(4096) # not in config.json
|
||||||
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
|
self.gguf_writer.add_block_count(block_count)
|
||||||
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||||
|
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||||
|
|
||||||
|
def shuffle_attn_q_weight(self, data_torch):
|
||||||
|
assert data_torch.size() == (5120, 5120)
|
||||||
|
data_torch = data_torch.reshape(8, 5, 128, 5120)
|
||||||
|
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
|
||||||
|
data_torch = torch.reshape(data_torch, (5120, 5120))
|
||||||
|
return data_torch
|
||||||
|
|
||||||
|
def shuffle_attn_output_weight(self, data_torch):
|
||||||
|
assert data_torch.size() == (5120, 5120)
|
||||||
|
data_torch = data_torch.reshape(5120, 8, 5, 128)
|
||||||
|
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
|
||||||
|
data_torch = torch.reshape(data_torch, (5120, 5120))
|
||||||
|
return data_torch
|
||||||
|
|
||||||
|
def write_tensors(self):
|
||||||
|
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
|
||||||
|
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||||
|
|
||||||
|
for name, data_torch in self.get_tensors():
|
||||||
|
if "self_attn.rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# map tensor names
|
||||||
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
|
if new_name is None:
|
||||||
|
print(f"Can not map tensor {name!r}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
# shuffle for broadcasting of gqa in ggml_mul_mat
|
||||||
|
if new_name.endswith("attn_q.weight"):
|
||||||
|
data_torch = self.shuffle_attn_q_weight(data_torch)
|
||||||
|
elif new_name.endswith("attn_output.weight"):
|
||||||
|
data_torch = self.shuffle_attn_output_weight(data_torch)
|
||||||
|
|
||||||
|
old_dtype = data_torch.dtype
|
||||||
|
|
||||||
|
# convert any unsupported data types to float32
|
||||||
|
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||||
|
data_torch = data_torch.to(torch.float32)
|
||||||
|
|
||||||
|
data = data_torch.squeeze().numpy()
|
||||||
|
|
||||||
|
n_dims = len(data.shape)
|
||||||
|
data_dtype = data.dtype
|
||||||
|
|
||||||
|
# if f32 desired, convert any float16 to float32
|
||||||
|
if self.ftype == 0 and data_dtype == np.float16:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||||
|
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||||
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert a huggingface model to a GGML compatible file")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-only", action="store_true",
|
"--vocab-only", action="store_true",
|
||||||
help="extract only the vocab",
|
help="extract only the vocab",
|
||||||
|
@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
STABLELM = auto()
|
STABLELM = auto()
|
||||||
QWEN = auto()
|
QWEN = auto()
|
||||||
PHI2 = auto()
|
PHI2 = auto()
|
||||||
|
PLAMO = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
@ -142,6 +143,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.STABLELM: "stablelm",
|
MODEL_ARCH.STABLELM: "stablelm",
|
||||||
MODEL_ARCH.QWEN: "qwen",
|
MODEL_ARCH.QWEN: "qwen",
|
||||||
MODEL_ARCH.PHI2: "phi2",
|
MODEL_ARCH.PHI2: "phi2",
|
||||||
|
MODEL_ARCH.PLAMO: "plamo",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
@ -349,6 +351,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.PLAMO: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
],
|
||||||
MODEL_ARCH.GPT2: [
|
MODEL_ARCH.GPT2: [
|
||||||
# TODO
|
# TODO
|
||||||
],
|
],
|
||||||
|
@ -79,6 +79,7 @@ class TensorNameMap:
|
|||||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||||
"model.layers.{bid}.ln1", # yi
|
"model.layers.{bid}.ln1", # yi
|
||||||
"transformer.h.{bid}.ln", # phi2
|
"transformer.h.{bid}.ln", # phi2
|
||||||
|
"model.layers.layers.{bid}.norm", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
@ -99,26 +100,29 @@ class TensorNameMap:
|
|||||||
|
|
||||||
# Attention query
|
# Attention query
|
||||||
MODEL_TENSOR.ATTN_Q: (
|
MODEL_TENSOR.ATTN_Q: (
|
||||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf
|
"model.layers.{bid}.self_attn.q_proj", # llama-hf
|
||||||
"layers.{bid}.attention.wq", # llama-pth
|
"layers.{bid}.attention.wq", # llama-pth
|
||||||
"encoder.layer.{bid}.attention.self.query", # bert
|
"encoder.layer.{bid}.attention.self.query", # bert
|
||||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||||
|
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention key
|
# Attention key
|
||||||
MODEL_TENSOR.ATTN_K: (
|
MODEL_TENSOR.ATTN_K: (
|
||||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf
|
"model.layers.{bid}.self_attn.k_proj", # llama-hf
|
||||||
"layers.{bid}.attention.wk", # llama-pth
|
"layers.{bid}.attention.wk", # llama-pth
|
||||||
"encoder.layer.{bid}.attention.self.key", # bert
|
"encoder.layer.{bid}.attention.self.key", # bert
|
||||||
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
||||||
|
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention value
|
# Attention value
|
||||||
MODEL_TENSOR.ATTN_V: (
|
MODEL_TENSOR.ATTN_V: (
|
||||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf
|
"model.layers.{bid}.self_attn.v_proj", # llama-hf
|
||||||
"layers.{bid}.attention.wv", # llama-pth
|
"layers.{bid}.attention.wv", # llama-pth
|
||||||
"encoder.layer.{bid}.attention.self.value", # bert
|
"encoder.layer.{bid}.attention.self.value", # bert
|
||||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||||
|
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention output
|
# Attention output
|
||||||
@ -134,12 +138,14 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||||
"transformer.h.{bid}.mixer.out_proj", # phi2
|
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||||
|
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rotary embeddings
|
# Rotary embeddings
|
||||||
MODEL_TENSOR.ATTN_ROT_EMBD: (
|
MODEL_TENSOR.ATTN_ROT_EMBD: (
|
||||||
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
|
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
|
||||||
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
|
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
|
||||||
|
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward norm
|
# Feed-forward norm
|
||||||
@ -174,6 +180,7 @@ class TensorNameMap:
|
|||||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||||
"transformer.h.{bid}.mlp.w1", # qwen
|
"transformer.h.{bid}.mlp.w1", # qwen
|
||||||
"transformer.h.{bid}.mlp.fc1", # phi2
|
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||||
|
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_EXP: (
|
MODEL_TENSOR.FFN_UP_EXP: (
|
||||||
@ -186,6 +193,7 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
||||||
"layers.{bid}.feed_forward.w1", # llama-pth
|
"layers.{bid}.feed_forward.w1", # llama-pth
|
||||||
"transformer.h.{bid}.mlp.w2", # qwen
|
"transformer.h.{bid}.mlp.w2", # qwen
|
||||||
|
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||||
@ -206,6 +214,7 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||||
"transformer.h.{bid}.mlp.fc2", # phi2
|
"transformer.h.{bid}.mlp.fc2", # phi2
|
||||||
|
"model.layers.layers.{bid}.mlp.down_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||||
|
181
llama.cpp
181
llama.cpp
@ -198,6 +198,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_STABLELM,
|
LLM_ARCH_STABLELM,
|
||||||
LLM_ARCH_QWEN,
|
LLM_ARCH_QWEN,
|
||||||
LLM_ARCH_PHI2,
|
LLM_ARCH_PHI2,
|
||||||
|
LLM_ARCH_PLAMO,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -216,6 +217,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||||
{ LLM_ARCH_QWEN, "qwen" },
|
{ LLM_ARCH_QWEN, "qwen" },
|
||||||
{ LLM_ARCH_PHI2, "phi2" },
|
{ LLM_ARCH_PHI2, "phi2" },
|
||||||
|
{ LLM_ARCH_PLAMO, "plamo" },
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llm_kv {
|
enum llm_kv {
|
||||||
@ -567,6 +569,24 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_PLAMO,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
@ -2749,6 +2769,15 @@ static void llm_load_hparams(
|
|||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PLAMO:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 40: model.type = e_model::MODEL_13B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
@ -3630,6 +3659,51 @@ static bool llm_load_tensors(
|
|||||||
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
|
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PLAMO:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
ggml_backend_type backend_norm;
|
||||||
|
ggml_backend_type backend_output;
|
||||||
|
|
||||||
|
if (n_gpu_layers > int(n_layer)) {
|
||||||
|
backend_norm = llama_backend_offload;
|
||||||
|
backend_output = llama_backend_offload_split;
|
||||||
|
} else {
|
||||||
|
backend_norm = GGML_BACKEND_CPU;
|
||||||
|
backend_output = GGML_BACKEND_CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||||
|
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_ff = hparams.n_ff;
|
||||||
|
|
||||||
|
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||||
|
|
||||||
|
model.layers.resize(n_layer);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
|
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
|
||||||
|
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||||
|
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||||
|
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
@ -5555,6 +5629,109 @@ struct llm_build_context {
|
|||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_plamo() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
|
||||||
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
|
// shift the entire K-cache if needed
|
||||||
|
if (do_rope_shift) {
|
||||||
|
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * attention_norm = cur;
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
Qcur = ggml_rope_custom(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||||
|
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_custom(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||||
|
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
|
model.layers[il].wo, NULL,
|
||||||
|
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
}
|
||||||
|
struct ggml_tensor * sa_out = cur;
|
||||||
|
|
||||||
|
cur = attention_norm;
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
{
|
||||||
|
cur = llm_build_ffn(ctx0, cur,
|
||||||
|
model.layers[il].ffn_up, NULL,
|
||||||
|
model.layers[il].ffn_gate, NULL,
|
||||||
|
model.layers[il].ffn_down, NULL,
|
||||||
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, sa_out);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, inpL);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -6065,6 +6242,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_phi2();
|
result = llm.build_phi2();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PLAMO:
|
||||||
|
{
|
||||||
|
result = llm.build_plamo();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user