mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
phi2 implementation
This commit is contained in:
parent
6744dbe924
commit
12cc80cb89
@ -182,6 +182,8 @@ class Model:
|
|||||||
return QwenModel
|
return QwenModel
|
||||||
if model_architecture == "MixtralForCausalLM":
|
if model_architecture == "MixtralForCausalLM":
|
||||||
return MixtralModel
|
return MixtralModel
|
||||||
|
if model_architecture == "PhiForCausalLM":
|
||||||
|
return Phi2Model
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
@ -221,6 +223,8 @@ class Model:
|
|||||||
return gguf.MODEL_ARCH.QWEN
|
return gguf.MODEL_ARCH.QWEN
|
||||||
if arch == "MixtralForCausalLM":
|
if arch == "MixtralForCausalLM":
|
||||||
return gguf.MODEL_ARCH.LLAMA
|
return gguf.MODEL_ARCH.LLAMA
|
||||||
|
if arch == "PhiForCausalLM":
|
||||||
|
return gguf.MODEL_ARCH.PHI2
|
||||||
|
|
||||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||||
|
|
||||||
@ -980,6 +984,21 @@ class QwenModel(Model):
|
|||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
class Phi2Model(Model):
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
block_count = self.hparams["n_layer"]
|
||||||
|
|
||||||
|
self.gguf_writer.add_name("Phi2")
|
||||||
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||||
|
self.gguf_writer.add_block_count(block_count)
|
||||||
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
|
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
|
||||||
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
|
||||||
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
BLOOM = auto()
|
BLOOM = auto()
|
||||||
STABLELM = auto()
|
STABLELM = auto()
|
||||||
QWEN = auto()
|
QWEN = auto()
|
||||||
|
PHI2 = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.BLOOM: "bloom",
|
MODEL_ARCH.BLOOM: "bloom",
|
||||||
MODEL_ARCH.STABLELM: "stablelm",
|
MODEL_ARCH.STABLELM: "stablelm",
|
||||||
MODEL_ARCH.QWEN: "qwen",
|
MODEL_ARCH.QWEN: "qwen",
|
||||||
|
MODEL_ARCH.PHI2: "phi2",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
@ -350,6 +352,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_ARCH.GPT2: [
|
MODEL_ARCH.GPT2: [
|
||||||
# TODO
|
# TODO
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.PHI2: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
]
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ class TensorNameMap:
|
|||||||
"tok_embeddings", # llama-pth
|
"tok_embeddings", # llama-pth
|
||||||
"embeddings.word_embeddings", # bert
|
"embeddings.word_embeddings", # bert
|
||||||
"language_model.embedding.word_embeddings", # persimmon
|
"language_model.embedding.word_embeddings", # persimmon
|
||||||
|
"transformer.embd.wte", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Token type embeddings
|
# Token type embeddings
|
||||||
@ -41,6 +42,7 @@ class TensorNameMap:
|
|||||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
|
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
|
||||||
"output", # llama-pth bloom
|
"output", # llama-pth bloom
|
||||||
"word_embeddings_for_head", # persimmon
|
"word_embeddings_for_head", # persimmon
|
||||||
|
"lm_head.linear", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Output norm
|
# Output norm
|
||||||
@ -53,6 +55,7 @@ class TensorNameMap:
|
|||||||
"transformer.norm_f", # mpt
|
"transformer.norm_f", # mpt
|
||||||
"ln_f", # refact bloom qwen
|
"ln_f", # refact bloom qwen
|
||||||
"language_model.encoder.final_layernorm", # persimmon
|
"language_model.encoder.final_layernorm", # persimmon
|
||||||
|
"lm_head.ln", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rope frequencies
|
# Rope frequencies
|
||||||
@ -75,6 +78,7 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||||
"model.layers.{bid}.ln1", # yi
|
"model.layers.{bid}.ln1", # yi
|
||||||
|
"transformer.h.{bid}.ln", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
@ -90,6 +94,7 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||||
"h.{bid}.self_attention.query_key_value", # bloom
|
"h.{bid}.self_attention.query_key_value", # bloom
|
||||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||||
|
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention query
|
# Attention query
|
||||||
@ -128,6 +133,7 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||||
|
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rotary embeddings
|
# Rotary embeddings
|
||||||
@ -167,6 +173,7 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||||
"transformer.h.{bid}.mlp.w1", # qwen
|
"transformer.h.{bid}.mlp.w1", # qwen
|
||||||
|
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_EXP: (
|
MODEL_TENSOR.FFN_UP_EXP: (
|
||||||
@ -198,6 +205,7 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.output.dense", # bert
|
"encoder.layer.{bid}.output.dense", # bert
|
||||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||||
|
"transformer.h.{bid}.mlp.fc2", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||||
|
187
llama.cpp
187
llama.cpp
@ -195,6 +195,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_BLOOM,
|
LLM_ARCH_BLOOM,
|
||||||
LLM_ARCH_STABLELM,
|
LLM_ARCH_STABLELM,
|
||||||
LLM_ARCH_QWEN,
|
LLM_ARCH_QWEN,
|
||||||
|
LLM_ARCH_PHI2,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -212,6 +213,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_BLOOM, "bloom" },
|
{ LLM_ARCH_BLOOM, "bloom" },
|
||||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||||
{ LLM_ARCH_QWEN, "qwen" },
|
{ LLM_ARCH_QWEN, "qwen" },
|
||||||
|
{ LLM_ARCH_PHI2, "phi2" },
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llm_kv {
|
enum llm_kv {
|
||||||
@ -550,6 +552,19 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_PHI2,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
@ -1420,6 +1435,7 @@ struct llama_model {
|
|||||||
struct ggml_tensor * output_norm;
|
struct ggml_tensor * output_norm;
|
||||||
struct ggml_tensor * output_norm_b;
|
struct ggml_tensor * output_norm_b;
|
||||||
struct ggml_tensor * output;
|
struct ggml_tensor * output;
|
||||||
|
struct ggml_tensor * output_b;
|
||||||
|
|
||||||
std::vector<llama_layer> layers;
|
std::vector<llama_layer> layers;
|
||||||
|
|
||||||
@ -3625,7 +3641,77 @@ static void llm_load_tensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
{
|
||||||
|
// TODO: CPU-only for now
|
||||||
|
|
||||||
|
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
ggml_backend_type backend_norm;
|
||||||
|
ggml_backend_type backend_output;
|
||||||
|
|
||||||
|
if (n_gpu_layers > int(n_layer)) {
|
||||||
|
backend_norm = llama_backend_offload;
|
||||||
|
backend_output = llama_backend_offload_split;
|
||||||
|
} else {
|
||||||
|
backend_norm = GGML_BACKEND_CPU;
|
||||||
|
backend_output = GGML_BACKEND_CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||||
|
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
|
||||||
|
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||||
|
model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output);
|
||||||
|
|
||||||
|
if (backend_norm == GGML_BACKEND_GPU) {
|
||||||
|
vram_weights += ggml_nbytes(model.output_norm);
|
||||||
|
vram_weights += ggml_nbytes(model.output_norm_b);
|
||||||
|
}
|
||||||
|
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
||||||
|
vram_weights += ggml_nbytes(model.output);
|
||||||
|
vram_weights += ggml_nbytes(model.output_b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_ff = hparams.n_ff;
|
||||||
|
|
||||||
|
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||||
|
|
||||||
|
model.layers.resize(n_layer);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
|
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
|
||||||
|
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||||
|
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
||||||
|
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
|
||||||
|
|
||||||
|
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
|
||||||
|
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
|
||||||
|
|
||||||
|
if (backend == GGML_BACKEND_GPU) {
|
||||||
|
vram_weights +=
|
||||||
|
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
|
||||||
|
ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
|
||||||
|
ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) +
|
||||||
|
ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) +
|
||||||
|
ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
@ -5417,6 +5503,101 @@ struct llm_build_context {
|
|||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
struct ggml_cgraph * build_phi2() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * attn_norm_output;
|
||||||
|
struct ggml_tensor * ffn_output;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
|
||||||
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
|
// KQ_scale
|
||||||
|
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
cb(KQ_scale, "KQ_scale", -1);
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
|
||||||
|
attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm,
|
||||||
|
model.layers[il].attn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(attn_norm_output, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
|
||||||
|
cb(cur, "wqkv", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||||
|
cb(cur, "bqkv", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||||
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||||
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||||
|
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
// RoPE
|
||||||
|
Qcur = ggml_rope(ctx0, Qcur, inp_pos, 32, 2, 0);
|
||||||
|
Kcur = ggml_rope(ctx0, Kcur, inp_pos, 32, 2, 0);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
|
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// FF
|
||||||
|
{
|
||||||
|
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
|
||||||
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
|
NULL, NULL,
|
||||||
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
|
cb(ffn_output, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
inpL = ggml_add(ctx0, cur, ggml_add_inplace(ctx0, ffn_output, inpL));
|
||||||
|
cb(inpL, "l_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.output_norm,
|
||||||
|
model.output_norm_b,
|
||||||
|
LLM_NORM, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, model.output_b);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -5917,6 +6098,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_qwen();
|
result = llm.build_qwen();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
{
|
||||||
|
result = llm.build_phi2();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
@ -6051,7 +6236,7 @@ static int llama_decode_internal(
|
|||||||
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
||||||
|
|
||||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 3];
|
||||||
|
|
||||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||||
|
Loading…
Reference in New Issue
Block a user