mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
vec move mask to shmem
This commit is contained in:
parent
3b9625032c
commit
94accca4c2
@ -3297,7 +3297,7 @@ static void ggml_metal_encode_node(
|
|||||||
// ne00*(nsg)
|
// ne00*(nsg)
|
||||||
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
||||||
//
|
//
|
||||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
||||||
|
|
||||||
int64_t nsgmax = 2;
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
|
@ -2844,7 +2844,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const short D8 = D/8;
|
const short D8 = D/8;
|
||||||
const short D16 = D/16;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
|
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
||||||
|
|
||||||
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
||||||
const short T = D + 2*TS; // shared memory size per query in (half)
|
const short T = D + 2*TS; // shared memory size per query in (half)
|
||||||
@ -3353,16 +3353,17 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const short D16 = D/16;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short NW4 = NW/4;
|
const short NW4 = NW/4;
|
||||||
const short SH = C; // shared memory per simdgroup in (half)
|
const short SH = 2*C; // shared memory per simdgroup
|
||||||
|
|
||||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||||
|
|
||||||
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4
|
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4
|
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
|
||||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
|
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
|
||||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in half4
|
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in s4_t
|
||||||
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
threadgroup half * sm = (threadgroup half *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask
|
||||||
|
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
o4x4_t lo[D16/NW4];
|
o4x4_t lo[D16/NW4];
|
||||||
@ -3412,8 +3413,10 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
mq[ii/NW4] = sq4x4[ii + tx];
|
mq[ii/NW4] = sq4x4[ii + tx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const bool has_mask = mask != q;
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
device const half * pm = (device const half *) (mask + iq1*nb31);
|
||||||
|
|
||||||
half slope = 1.0f;
|
half slope = 1.0f;
|
||||||
|
|
||||||
@ -3435,6 +3438,10 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
sm[tiisg] = pm[ic + tiisg];
|
||||||
|
}
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
// each simdgroup processes 1 query and 4 keys
|
// each simdgroup processes 1 query and 4 keys
|
||||||
@ -3476,7 +3483,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
mqk = logit_softcap*precise::tanh(mqk);
|
mqk = logit_softcap*precise::tanh(mqk);
|
||||||
}
|
}
|
||||||
|
|
||||||
mqk += (s_t) ((mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f);
|
mqk += sm[4*cc + ty]*slope;
|
||||||
|
|
||||||
ss[4*cc + ty] = mqk;
|
ss[4*cc + ty] = mqk;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user