llama : fix K-shift with quantized K (wip)

This commit is contained in:
slaren 2024-02-22 00:28:39 +01:00
parent 7fe4678b02
commit 5271c75666

View File

@ -4671,17 +4671,30 @@ static void llm_build_k_shift(
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp = struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
// we rotate only the first n_rot dimensions n_embd_head_k, n_head_kv, n_ctx,
ggml_rope_custom_inplace(ctx, ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
ggml_view_3d(ctx, kv.k_l[il], ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
n_embd_head_k, n_head_kv, n_ctx, 0);
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), struct ggml_tensor * tmp;
0), if (ggml_is_quantized(k->type)) {
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx, k, GGML_TYPE_F32);
cb(tmp, "K_f32", il);
tmp = ggml_rope_custom_inplace(ctx, tmp,
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il); cb(tmp, "K_shifted_f32", il);
tmp = ggml_cpy(ctx, tmp, k);
cb(tmp, "K_shifted_q", il);
} else {
// we rotate only the first n_rot dimensions
tmp = ggml_rope_custom_inplace(ctx, k,
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
}
ggml_build_forward_expand(graph, tmp); ggml_build_forward_expand(graph, tmp);
} }
} }