diff --git a/ggml-metal.metal b/ggml-metal.metal
index b79a1ba56..28847794c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32(
 }
 
 kernel void kernel_flash_attn_ext_f16(
-        device const half * q,
-        device const half * k,
-        device const half * v,
-        device const half * mask,
+        device const  half * q,
+        device const  half * k,
+        device const  half * v,
+        device const float * mask,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
diff --git a/ggml.c b/ggml.c
index e01d938ce..9cf4784ce 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
     if (params->type == GGML_TASK_INIT) {
         return;
     }
@@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
             for (int64_t ic = 0; ic < nek1; ++ic) {
                 // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
+                const int ik3 = iq3 / rk3;
+                const int ik2 = iq2 / rk2;
                 const int ik1 = ic;
 
                 // S indices
@@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         } else {
             for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
                 // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
+                const int ik3 = iq3 / rk3;
+                const int ik2 = iq2 / rk2;
                 const int ik1 = ic;
 
                 // S indices
@@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const int i3 = iq3;
 
                 // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
+                const int iv2 = iq2 / rv2;
+                const int iv3 = iq3 / rv3;
 
                 ggml_vec_dot_f16(nev0,
                         (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const int i3 = iq3;
 
                 // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
+                const int iv2 = iq2 / rv2;
+                const int iv3 = iq3 / rv3;
 
                 ggml_vec_dot_f16_unroll(nev0, nbv1,
                         (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
diff --git a/llama.cpp b/llama.cpp
index cec23c23f..d4bebe520 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kqv;
 
     if (supports_flash_attn) {
+        //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
+        //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
+        //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
+        //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
         kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
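
For context, a minimal sketch (not part of the patch) of what the new broadcast factors express: with grouped-query attention the Q tensor has more heads along dimension 2 than K/V, and each group of rk2 = neq2/nek2 consecutive query heads must read the same K/V head, so the K/V head index is obtained by integer division (iq2 / rk2) rather than the previous modulo mapping. The standalone C program below illustrates only that index mapping; the head counts and the names n_head_q and n_head_kv are assumed example values, not taken from the patch.

// Illustrative sketch of the K/V head broadcast mapping used above.
// All concrete values are assumed examples, not taken from the patch.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t n_head_q  = 32;                    // query heads (neq2), example value
    const int64_t n_head_kv =  8;                    // K/V heads   (nek2), example value
    const int64_t rk2       = n_head_q/n_head_kv;    // broadcast factor, as in the patch

    // each block of rk2 consecutive query heads reads the same K/V head:
    // ik2 = iq2 / rk2 (replacing the old iq2 % nek2 mapping)
    for (int64_t iq2 = 0; iq2 < n_head_q; ++iq2) {
        const int64_t ik2 = iq2/rk2;
        printf("q head %2lld -> kv head %lld\n", (long long) iq2, (long long) ik2);
    }

    return 0;
}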