mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
This commit is contained in:
parent
dd0d9ed102
commit
1e129611b1
@ -3187,7 +3187,7 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
nsg /= 2;
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
|
||||
float S[Q] = { [0 ... Q-1] = 0.0h };
|
||||
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
const short tx = tiisg%4;
|
||||
const short ty = tiisg/4;
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
|
||||
// broadcast
|
||||
// broadcast k
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const short ik2 = iq2/rk2;
|
||||
const short ik3 = iq3/rk3;
|
||||
|
||||
// v indices
|
||||
// broadcast v
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
const short iv2 = iq2/rv2;
|
||||
const short iv3 = iq3/rv3;
|
||||
|
||||
@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
const short i = tiisg%4;
|
||||
const short j = tiisg/4;
|
||||
|
||||
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||
skv4[4*j + i] = tmp;
|
||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -2908,10 +2908,10 @@ kernel void kernel_flash_attn_ext(
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||
}
|
||||
} else {
|
||||
if (ii + i < D16) {
|
||||
if (ii + tx < D16) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||
skv4[4*j + i] = tmp;
|
||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -3006,15 +3006,12 @@ kernel void kernel_flash_attn_ext(
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
const short i = tiisg%4;
|
||||
const short j = tiisg/4;
|
||||
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||
skv4[4*j + i] = tmp;
|
||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -3029,10 +3026,10 @@ kernel void kernel_flash_attn_ext(
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||
}
|
||||
} else {
|
||||
if (ii + i < D16) {
|
||||
if (ii + tx < D16) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||
skv4[4*j + i] = tmp;
|
||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
|
||||
kernel void kernel_flash_attn_ext_vec(
|
||||
@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = iq2;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
half4x4 lo[D16/NW4];
|
||||
float4x4 lo[D16/NW4];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
||||
@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D16/NW4; i += NW4) {
|
||||
lo[i] = half4x4(0.0h);
|
||||
lo[i] = float4x4(0.0h);
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
@ -3284,42 +3271,53 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
float S = { 0.0h };
|
||||
float M = { -FLT_MAX/2 };
|
||||
float S = 0.0h;
|
||||
float M = -FLT_MAX/2;
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
const short tx = tiisg%8;
|
||||
const short ty = tiisg/8;
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
|
||||
// broadcast
|
||||
// broadcast k
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short ik2 = iq2/rk2;
|
||||
const short ik3 = iq3/rk3;
|
||||
|
||||
// broadcast v
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const short ik2 = iq2 / rk2;
|
||||
const short ik3 = iq3 / rk3;
|
||||
|
||||
// v indices
|
||||
const short iv2 = iq2 / rv2;
|
||||
const short iv3 = iq3 / rv3;
|
||||
const short iv2 = iq2/rv2;
|
||||
const short iv3 = iq3/rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
float4x4 mq[D16/NW4];
|
||||
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
short i = ii + tiisg%8;
|
||||
mq[ii/NW4][0] = (float4) sq4[4*i + 0];
|
||||
mq[ii/NW4][1] = (float4) sq4[4*i + 1];
|
||||
mq[ii/NW4][2] = (float4) sq4[4*i + 2];
|
||||
mq[ii/NW4][3] = (float4) sq4[4*i + 3];
|
||||
mq[ii/NW4] = (float4x4) sq44[ii + tx];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = iq2;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
@ -3331,18 +3329,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
const short j = tiisg/8;
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
float mqk = 0.0;
|
||||
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
float4x4 mk;
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
const short i = ii + tiisg%8; // 0..7
|
||||
const short i = ii + tx;
|
||||
|
||||
float4x4 mk;
|
||||
dequantize_func(pk + i/nl, i%nl, mk);
|
||||
|
||||
mqk +=
|
||||
@ -3364,16 +3360,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
|
||||
// mqk = mqk*scale + mask*slope
|
||||
if (tiisg%8 == 0) {
|
||||
if (tx == 0) {
|
||||
mqk *= scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
mqk = logit_softcap*precise::tanh(mqk);
|
||||
}
|
||||
|
||||
mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f;
|
||||
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
|
||||
|
||||
ss[4*cc + j] = mqk;
|
||||
ss[4*cc + ty] = mqk;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3408,20 +3404,20 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
const short j = tiisg/8;
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
const float4x4 lss(ss[4*cc + ty]);
|
||||
|
||||
float4x4 mv;
|
||||
const float4x4 lss(ss[4*cc + j]);
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
const short i = ii + tiisg%8;
|
||||
const short i = ii + tx;
|
||||
|
||||
float4x4 mv;
|
||||
dequantize_func(pv4 + i/nl, i%nl, mv);
|
||||
|
||||
lo[ii/NW4] += (half4x4)(mv*lss);
|
||||
lo[ii/NW4] += mv*lss;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3458,14 +3454,8 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
|
||||
// store results to shared memory
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
short i = ii + tiisg;
|
||||
if (tiisg < 8) {
|
||||
sr4[4*i + 0] = lo[ii/NW4][0];
|
||||
sr4[4*i + 1] = lo[ii/NW4][1];
|
||||
sr4[4*i + 2] = lo[ii/NW4][2];
|
||||
sr4[4*i + 3] = lo[ii/NW4][3];
|
||||
}
|
||||
for (short i = tiisg; i < D16; i += NW4) {
|
||||
sr44[i] = lo[i/NW4];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -3492,24 +3482,22 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
device float4x4 * dst44 = (device float4x4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const float S = ss[0];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user