diff --git a/llama.cpp b/llama.cpp
index 581a8399d..ef2cc57ca 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -28,6 +28,8 @@
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 
+#define LLAMA_USE_FLASH_ATTN
+
 #define LLAMA_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -829,6 +831,30 @@ static bool llama_eval_internal(
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }
 
+#ifdef LLAMA_USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                         Qcur,
@@ -872,6 +898,7 @@ static bool llama_eval_internal(
             // is there a better way?
             struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)