phi-2 : various fixes

This commit is contained in:
Georgi Gerganov 2023-12-16 10:46:18 +02:00
parent e20765534d
commit a2a3d2c8d7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 66 additions and 37 deletions

View File

@ -249,7 +249,7 @@ class Model:
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
if hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
@ -998,7 +998,7 @@ class Phi2Model(Model):
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
self.gguf_writer.add_file_type(self.ftype)
###### CONVERSION LOGIC ######

View File

@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;
const int i = row*ncols + ib*n_dims + ic/2;
// IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2)
// I don't know what we are supposed to compute, because the row is not divisible by n_dims
// this check matches the CPU code, but it is likely wrong as well
// I can't understand the Python code, so if you know what to do here, please fix it
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
if (ncols % n_dims != 0 && ib == ncols/n_dims) {
return;
}
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;

4
ggml.c
View File

@ -9168,6 +9168,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@ -9237,6 +9239,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {

View File

@ -2998,7 +2998,7 @@ static void llm_load_tensors(
(void) main_gpu;
enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
#ifdef GGML_USE_CUBLAS
@ -3643,9 +3643,7 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_PHI2:
{
// TODO: CPU-only for now
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
@ -3654,7 +3652,7 @@ static void llm_load_tensors(
if (n_gpu_layers > int(n_layer)) {
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
backend_output = llama_backend_offload;
} else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
@ -3663,13 +3661,11 @@ static void llm_load_tensors(
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output);
model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
vram_weights += ggml_nbytes(model.output_norm_b);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
vram_weights += ggml_nbytes(model.output_b);
}
@ -3687,20 +3683,20 @@ static void llm_load_tensors(
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@ -5401,15 +5397,15 @@ struct llm_build_context {
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
@ -5528,8 +5524,12 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
for (int il = 0; il < n_layer; ++il) {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
@ -5552,14 +5552,19 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
// RoPE
Qcur = ggml_rope(ctx0, Qcur, inp_pos, 32, 2, 0);
Kcur = ggml_rope(ctx0, Kcur, inp_pos, 32, 2, 0);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@ -5580,8 +5585,13 @@ struct llm_build_context {
cb(ffn_output, "ffn_out", il);
}
inpL = ggml_add(ctx0, cur, ggml_add_inplace(ctx0, ffn_output, inpL));
cb(inpL, "l_out", il);
cur = ggml_add(ctx0, cur, ffn_output);
cb(cur, "l_out", il);
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "l_out", il);
inpL = cur;
}
cur = llm_build_norm(ctx0, inpL, hparams,
@ -5589,9 +5599,9 @@ struct llm_build_context {
model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_norm", -1);
cb(cur, "result_output_no_bias", -1);
cur = ggml_add(ctx0, cur, model.output_b);
cb(cur, "result_output", -1);
@ -5613,7 +5623,7 @@ enum llm_offload_func_e {
OFFLOAD_FUNC_FRC, // force offload
OFFLOAD_FUNC_KQV,
OFFLOAD_FUNC_NR,
OFFLOAD_FUNC_EMB,
OFFLOAD_FUNC_EMB, // embeddings
OFFLOAD_FUNC_OUT,
};
@ -5782,6 +5792,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "l_out", OFFLOAD_FUNC },
{ "result_norm", OFFLOAD_FUNC_EMB },
{ "result_output_no_bias", OFFLOAD_FUNC_EMB },
{ "result_output", OFFLOAD_FUNC_OUT },
};
@ -6235,12 +6246,16 @@ static int llama_decode_internal(
ggml_allocr_alloc_graph(lctx.alloc, gf);
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
// the embeddings could be the second to last tensor, or the third to last tensor
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
if (strcmp(embeddings->name, "result_norm") != 0) {
embeddings = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
}
#ifdef GGML_USE_CUBLAS
for (int i = 0; i < gf->n_leafs; i++) {

View File

@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
}
test_cases.emplace_back(new test_alibi());