mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
wip 3
This commit is contained in:
parent
2335086fd3
commit
9bd5ae09ae
@ -3161,11 +3161,11 @@ kernel void kernel_flash_attn_ext(
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const float S0 = (s_t) ss[j*TS + 0];
|
||||
const float S1 = (s_t) ss[j*TS + sg*SH + 0];
|
||||
const float S0 = ss[j*TS + 0];
|
||||
const float S1 = ss[j*TS + sg*SH + 0];
|
||||
|
||||
const float M0 = (s_t) ss[j*TS + 1];
|
||||
const float M1 = (s_t) ss[j*TS + sg*SH + 1];
|
||||
const float M0 = ss[j*TS + 1];
|
||||
const float M1 = ss[j*TS + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
@ -3234,7 +3234,7 @@ kernel void kernel_flash_attn_ext(
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8, \
|
||||
half, half4, simdgroup_half8x8
|
||||
#else
|
||||
#define S_T float
|
||||
@ -3243,10 +3243,10 @@ kernel void kernel_flash_attn_ext(
|
||||
#define S8x8_T simdgroup_float8x8
|
||||
|
||||
#define FA_TYPES \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, half4, simdgroup_half8x8
|
||||
#endif
|
||||
|
||||
@ -3297,11 +3297,28 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
|
||||
|
||||
#undef FA_TYPES
|
||||
#undef S8x8_T
|
||||
#undef S4x4_T
|
||||
#undef S4_T
|
||||
#undef S_T
|
||||
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
|
||||
template<
|
||||
typename q4_t,
|
||||
typename q4x4_t,
|
||||
typename k4x4_t,
|
||||
typename v4x4_t,
|
||||
typename s_t, // attention accumulation types
|
||||
typename s4_t,
|
||||
typename s4x4_t,
|
||||
typename o4x4_t,
|
||||
typename block_q,
|
||||
short nl_k,
|
||||
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
|
||||
short nl_v,
|
||||
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
|
||||
short D, // head size
|
||||
short Q = 1, // queries per threadgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
@ -3350,37 +3367,39 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short NW4 = NW/4;
|
||||
const short SH = C; // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
const short SF = sizeof(s_t)/sizeof(half);
|
||||
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
||||
const short T = D + SF*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + SF*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup s4x4_t * sr4x4 = (threadgroup s4x4_t *) (shared + SF*sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
float4x4 lo[D16/NW4];
|
||||
o4x4_t lo[D16/NW4];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 < ne01) {
|
||||
sq4[i] = (half4) q4[i];
|
||||
sq4[i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[i] = 0.0h;
|
||||
sq4[i] = (q4_t) (float4) 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D16/NW4; i += NW4) {
|
||||
lo[i] = float4x4(0.0f);
|
||||
lo[i] = (o4x4_t) 0.0f;
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (short i = tiisg; i < SH/4; i += NW) {
|
||||
ss4[i] = 0.0h;
|
||||
ss4[i] = (s4_t) (float4) 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -3412,10 +3431,10 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short iv3 = iq3/rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
float4x4 mq[D16/NW4];
|
||||
k4x4_t mq[D16/NW4];
|
||||
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
mq[ii/NW4] = (float4x4) sq44[ii + tx];
|
||||
mq[ii/NW4] = (k4x4_t) sq4x4[ii + tx];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
@ -3445,7 +3464,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
float mqk = 0.0;
|
||||
s_t mqk = 0.0;
|
||||
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
@ -3453,8 +3472,8 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
const short i = ii + tx;
|
||||
|
||||
float4x4 mk;
|
||||
dequantize_func(pk + i/nl, i%nl, mk);
|
||||
k4x4_t mk;
|
||||
deq_k(pk + i/nl_k, i%nl_k, mk);
|
||||
|
||||
mqk +=
|
||||
dot(mq[ii/NW4][0], mk[0]) +
|
||||
@ -3482,7 +3501,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
mqk = logit_softcap*precise::tanh(mqk);
|
||||
}
|
||||
|
||||
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
|
||||
mqk += (s_t) ((mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f);
|
||||
|
||||
ss[4*cc + ty] = mqk;
|
||||
}
|
||||
@ -3523,16 +3542,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
const float4x4 lss(ss[4*cc + ty]);
|
||||
const s4x4_t ms(ss[4*cc + ty]);
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
const short i = ii + tx;
|
||||
|
||||
float4x4 mv;
|
||||
dequantize_func(pv4 + i/nl, i%nl, mv);
|
||||
v4x4_t mv;
|
||||
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
||||
|
||||
lo[ii/NW4] += mv*lss;
|
||||
lo[ii/NW4] += mv*ms;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3540,8 +3559,8 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
ss[0] = (s_t) S;
|
||||
ss[1] = (s_t) M;
|
||||
}
|
||||
}
|
||||
|
||||
@ -3570,7 +3589,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// store results to shared memory
|
||||
for (short i = tiisg; i < D16; i += NW4) {
|
||||
sr44[i] = lo[i/NW4];
|
||||
sr4x4[i] = lo[i/NW4];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -3592,13 +3611,13 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const float S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
ss[0] = (s_t) S;
|
||||
ss[1] = (s_t) M;
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
|
||||
sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -3612,26 +3631,45 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const float S = ss[0];
|
||||
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
|
||||
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
#define S_T float
|
||||
#define S4_T float4
|
||||
#define S4x4_T float4x4
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
#define FA_TYPES \
|
||||
half4, half4x4, \
|
||||
float4x4, \
|
||||
float4x4, \
|
||||
float, float4, float4x4, \
|
||||
float4x4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
|
||||
|
||||
#undef FA_TYPES
|
||||
#undef S4x4_T
|
||||
#undef S4_T
|
||||
#undef S_T
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy(
|
||||
|
Loading…
Reference in New Issue
Block a user