llama : add llm_build_inp_embd helper

This commit is contained in:
Georgi Gerganov 2023-10-31 16:43:08 +02:00
parent 2073347e3b
commit 7923b70cb8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

161
llama.cpp
View File

@ -1228,8 +1228,8 @@ struct llama_model {
llama_hparams hparams = {};
llama_vocab vocab;
struct ggml_tensor * tok_embeddings;
struct ggml_tensor * pos_embeddings;
struct ggml_tensor * tok_embd;
struct ggml_tensor * pos_embd;
struct ggml_tensor * tok_norm;
struct ggml_tensor * tok_norm_b;
@ -2484,7 +2484,7 @@ static void llm_load_tensors(
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
@ -2552,7 +2552,7 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_BAICHUAN:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
{
ggml_backend_type backend_norm;
ggml_backend_type backend_output;
@ -2620,7 +2620,7 @@ static void llm_load_tensors(
{
// TODO: CPU-only for now
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
@ -2696,8 +2696,8 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_STARCODER:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.pos_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
// output
{
@ -2775,7 +2775,7 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_PERSIMMON:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
{
ggml_backend_type backend_norm;
@ -2838,9 +2838,9 @@ static void llm_load_tensors(
{
// TODO: CPU-only for now
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
// output
{
@ -2918,7 +2918,7 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_MPT:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
@ -3099,6 +3099,31 @@ enum llm_rope_type {
LLM_ROPE_GLM,
};
static struct ggml_tensor * llm_build_inp_embd(
struct ggml_context * ctx,
const llama_batch & batch,
struct ggml_tensor * tok_embd,
int64_t n_embd,
int32_t n_tokens,
const llm_build_cb & cb) {
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
}
return inpL;
}
// Persimmon: n_rot = n_embd_head/2
// Other: n_rot = n_embd_head
static void llm_build_k_shift(
@ -3463,18 +3488,7 @@ static struct ggml_cgraph * llm_build_llama(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
@ -3619,18 +3633,7 @@ static struct ggml_cgraph * llm_build_baichaun(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
@ -3789,18 +3792,7 @@ static struct ggml_cgraph * llm_build_falcon(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
@ -3953,23 +3945,11 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * embd;
struct ggml_tensor * pos;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
embd = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
cb(embd, "inp_embd", -1);
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@ -3983,10 +3963,10 @@ static struct ggml_cgraph * llm_build_starcoder(
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
pos = ggml_get_rows(ctx0, model.pos_embeddings, inp_pos);
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
inpL = ggml_add(ctx0, embd, pos);
inpL = ggml_add(ctx0, inpL, pos);
cb(inpL, "inpL", -1);
for (int il = 0; il < n_layer; ++il) {
@ -4108,14 +4088,7 @@ static struct ggml_cgraph * llm_build_persimmon(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "imp_embd", -1);
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@ -4358,18 +4331,7 @@ static struct ggml_cgraph * llm_build_refact(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale
@ -4499,22 +4461,10 @@ static struct ggml_cgraph * llm_build_bloom(
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * embd;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
embd = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
cb(embd, "inp_embd", -1);
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
@ -4524,7 +4474,7 @@ static struct ggml_cgraph * llm_build_bloom(
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
inpL = llm_build_norm(ctx0, embd,
inpL = llm_build_norm(ctx0, inpL,
model.tok_norm,
model.tok_norm_b,
LLM_NORM, norm_eps, cb, -1);
@ -4648,18 +4598,7 @@ static struct ggml_cgraph * llm_build_mpt(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
}
inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale