mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
llama : add flash attention (demo)
This commit is contained in:
parent
986b6ce9f9
commit
36ddd12924
27
llama.cpp
27
llama.cpp
@ -28,6 +28,8 @@
|
|||||||
#define LLAMA_USE_SCRATCH
|
#define LLAMA_USE_SCRATCH
|
||||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||||
|
|
||||||
|
#define LLAMA_USE_FLASH_ATTN
|
||||||
|
|
||||||
#define LLAMA_ASSERT(x) \
|
#define LLAMA_ASSERT(x) \
|
||||||
do { \
|
do { \
|
||||||
if (!(x)) { \
|
if (!(x)) { \
|
||||||
@ -829,6 +831,30 @@ static bool llama_eval_internal(
|
|||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef LLAMA_USE_FLASH_ATTN
|
||||||
|
struct ggml_tensor * Q =
|
||||||
|
ggml_permute(ctx0,
|
||||||
|
ggml_cpy(ctx0,
|
||||||
|
Qcur,
|
||||||
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_embd/n_head, n_head, N)),
|
||||||
|
0, 2, 1, 3);
|
||||||
|
|
||||||
|
struct ggml_tensor * K =
|
||||||
|
ggml_permute(ctx0,
|
||||||
|
ggml_reshape_3d(ctx0,
|
||||||
|
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||||
|
n_embd/n_head, n_head, n_past + N),
|
||||||
|
0, 2, 1, 3);
|
||||||
|
|
||||||
|
struct ggml_tensor * V =
|
||||||
|
ggml_view_3d(ctx0, kv_self.v,
|
||||||
|
n_past + N, n_embd/n_head, n_head,
|
||||||
|
n_ctx*ggml_element_size(kv_self.v),
|
||||||
|
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
|
||||||
|
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
|
||||||
|
|
||||||
|
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
|
||||||
|
#else
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
Qcur,
|
Qcur,
|
||||||
@ -872,6 +898,7 @@ static bool llama_eval_internal(
|
|||||||
// is there a better way?
|
// is there a better way?
|
||||||
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||||
|
Loading…
Reference in New Issue
Block a user