From 5271c7566659081a626013989375cd52e4ad4762 Mon Sep 17 00:00:00 2001
From: slaren
Date: Thu, 22 Feb 2024 00:28:39 +0100
Subject: [PATCH] llama : fix K-shift with quantized K (wip)

---
 llama.cpp | 31 ++++++++++++++++++++++---------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index d763cc80c..10cd602ef 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4671,17 +4671,30 @@ static void llm_build_k_shift(
     }
 
     for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * tmp =
-            // we rotate only the first n_rot dimensions
-            ggml_rope_custom_inplace(ctx,
-                    ggml_view_3d(ctx, kv.k_l[il],
-                        n_embd_head_k, n_head_kv, n_ctx,
-                        ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
-                        ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
-                        0),
+        struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
+                n_embd_head_k, n_head_kv, n_ctx,
+                ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        struct ggml_tensor * tmp;
+        if (ggml_is_quantized(k->type)) {
+            // dequantize to f32 -> RoPE -> quantize back
+            tmp = ggml_cast(ctx, k, GGML_TYPE_F32);
+            cb(tmp, "K_f32", il);
+            tmp = ggml_rope_custom_inplace(ctx, tmp,
                     K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
-        cb(tmp, "K_shifted", il);
+            cb(tmp, "K_shifted_f32", il);
+            tmp = ggml_cpy(ctx, tmp, k);
+            cb(tmp, "K_shifted_q", il);
+        } else {
+            // we rotate only the first n_rot dimensions
+            tmp = ggml_rope_custom_inplace(ctx, k,
+                    K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+            cb(tmp, "K_shifted", il);
+        }
         ggml_build_forward_expand(graph, tmp);
     }
 }