diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b56be8448..aed84f1b2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -249,7 +249,7 @@ class Model: toktypes.append(gguf.TokenType.USER_DEFINED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: + if hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special: toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.USER_DEFINED) @@ -998,7 +998,7 @@ class Phi2Model(Model): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) self.gguf_writer.add_file_type(self.ftype) - + ###### CONVERSION LOGIC ###### diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0a63c1ecf..b2e086a19 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4998,7 +4998,16 @@ static __global__ void rope_neox( const int ib = col / n_dims; const int ic = col % n_dims; - const int i = row*ncols + ib*n_dims + ic/2; + // IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2) + // I don't know what we are supposed to compute, because the row is not divisible by n_dims + // this check matches the CPU code, but it is likely wrong as well + // I can't understand the Python code, so if you know what to do here, please fix it + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 + if (ncols % n_dims != 0 && ib == ncols/n_dims) { + return; + } + + const int i = row*ncols + ib*n_dims + ic/2; const int i2 = row/p_delta_rows; float cur_rot = inv_ndims * ic - ib; diff --git a/ggml.c b/ggml.c index 1feb7ead3..772be5616 100644 --- a/ggml.c +++ b/ggml.c @@ -9168,6 +9168,8 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9237,6 +9239,8 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { diff --git a/llama.cpp b/llama.cpp index 162692ce8..b4e62341b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2998,7 +2998,7 @@ static void llm_load_tensors( (void) main_gpu; - enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; + enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU; #ifdef GGML_USE_CUBLAS @@ -3643,9 +3643,7 @@ static void llm_load_tensors( } break; case LLM_ARCH_PHI2: { - // TODO: CPU-only for now - - model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { @@ -3654,7 +3652,7 @@ static void llm_load_tensors( if (n_gpu_layers > int(n_layer)) { backend_norm = llama_backend_offload; - backend_output = llama_backend_offload_split; + backend_output = llama_backend_offload; } else { backend_norm = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU; @@ -3663,13 +3661,11 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); + model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.output_norm); vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { vram_weights += ggml_nbytes(model.output); vram_weights += ggml_nbytes(model.output_b); } @@ -3687,20 +3683,20 @@ static void llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -5401,15 +5397,15 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // inp_pos - contains the positions - struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); // KQ_scale - struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5528,8 +5524,12 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); - for (int il = 0; il < n_layer; ++il) { + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + for (int il = 0; il < n_layer; ++il) { attn_norm_output = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -5552,14 +5552,19 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - // RoPE - Qcur = ggml_rope(ctx0, Qcur, inp_pos, 32, 2, 0); - Kcur = ggml_rope(ctx0, Kcur, inp_pos, 32, 2, 0); - cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); cb(Kcur, "Kcur", il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); @@ -5580,8 +5585,13 @@ struct llm_build_context { cb(ffn_output, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ggml_add_inplace(ctx0, ffn_output, inpL)); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_output); + cb(cur, "l_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -5589,9 +5599,9 @@ struct llm_build_context { model.output_norm_b, LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - + cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_norm", -1); + cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); @@ -5613,7 +5623,7 @@ enum llm_offload_func_e { OFFLOAD_FUNC_FRC, // force offload OFFLOAD_FUNC_KQV, OFFLOAD_FUNC_NR, - OFFLOAD_FUNC_EMB, + OFFLOAD_FUNC_EMB, // embeddings OFFLOAD_FUNC_OUT, }; @@ -5782,6 +5792,7 @@ static const std::unordered_map k_offload_map { "l_out", OFFLOAD_FUNC }, { "result_norm", OFFLOAD_FUNC_EMB }, + { "result_output_no_bias", OFFLOAD_FUNC_EMB }, { "result_output", OFFLOAD_FUNC_OUT }, }; @@ -6235,12 +6246,16 @@ static int llama_decode_internal( ggml_allocr_alloc_graph(lctx.alloc, gf); - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + // the output is always the last tensor in the graph + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + + // the embeddings could be the second to last tensor, or the third to last tensor struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } #ifdef GGML_USE_CUBLAS for (int i = 0; i < gf->n_leafs; i++) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index df2c3fb6e..f04b9438a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2) } test_cases.emplace_back(new test_alibi());