metal : clean-up (cont)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled

This commit is contained in:
Georgi Gerganov 2024-11-04 17:46:30 +02:00
parent dd0d9ed102
commit 1e129611b1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 76 additions and 88 deletions

View File

@ -3187,7 +3187,7 @@ static void ggml_metal_encode_node(
}
nsg /= 2;
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

View File

@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
float S[Q] = { [0 ... Q-1] = 0.0h };
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
// thread indices inside the simdgroup
const short tx = tiisg%4;
const short ty = tiisg/4;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
// k indices
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// v indices
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
const short i = tiisg%4;
const short j = tiisg/4;
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
if (D16%4 == 0) {
half4x4 tmp;
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
skv4[4*j + i] = tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
@ -2908,10 +2908,10 @@ kernel void kernel_flash_attn_ext(
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
} else {
if (ii + i < D16) {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
skv4[4*j + i] = tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
@ -3006,15 +3006,12 @@ kernel void kernel_flash_attn_ext(
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
const short i = tiisg%4;
const short j = tiisg/4;
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
if (D16%4 == 0) {
half4x4 tmp;
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
skv4[4*j + i] = tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
@ -3029,10 +3026,10 @@ kernel void kernel_flash_attn_ext(
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + i < D16) {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
skv4[4*j + i] = tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
// NOTE: can use half instead of float precision for some extra perf
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
kernel void kernel_flash_attn_ext_vec(
@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
half4x4 lo[D16/NW4];
float4x4 lo[D16/NW4];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
// zero out lo
for (short i = 0; i < D16/NW4; i += NW4) {
lo[i] = half4x4(0.0h);
lo[i] = float4x4(0.0h);
}
// zero out shared memory SH
@ -3284,25 +3271,28 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);
{
float S = { 0.0h };
float M = { -FLT_MAX/2 };
float S = 0.0h;
float M = -FLT_MAX/2;
// thread indices inside the simdgroup
const short tx = tiisg%8;
const short ty = tiisg/8;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
// k indices
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// v indices
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
@ -3310,16 +3300,24 @@ kernel void kernel_flash_attn_ext_vec(
float4x4 mq[D16/NW4];
for (short ii = 0; ii < D16; ii += NW4) {
short i = ii + tiisg%8;
mq[ii/NW4][0] = (float4) sq4[4*i + 0];
mq[ii/NW4][1] = (float4) sq4[4*i + 1];
mq[ii/NW4][2] = (float4) sq4[4*i + 2];
mq[ii/NW4][3] = (float4) sq4[4*i + 3];
mq[ii/NW4] = (float4x4) sq44[ii + tx];
}
// pointer to the mask
device const half * mp = (device const half *) (mask + iq1*nb31);
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@ -3331,18 +3329,16 @@ kernel void kernel_flash_attn_ext_vec(
// Q*K^T
{
// each simdgroup processes 1 query and 4 keys
const short j = tiisg/8;
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
float mqk = 0.0;
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
float4x4 mk;
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tiisg%8; // 0..7
const short i = ii + tx;
float4x4 mk;
dequantize_func(pk + i/nl, i%nl, mk);
mqk +=
@ -3364,16 +3360,16 @@ kernel void kernel_flash_attn_ext_vec(
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask*slope
if (tiisg%8 == 0) {
if (tx == 0) {
mqk *= scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}
mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f;
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
ss[4*cc + j] = mqk;
ss[4*cc + ty] = mqk;
}
}
}
@ -3408,20 +3404,20 @@ kernel void kernel_flash_attn_ext_vec(
// O = O + (Q*K^T)*V
{
const short j = tiisg/8;
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
const float4x4 lss(ss[4*cc + ty]);
float4x4 mv;
const float4x4 lss(ss[4*cc + j]);
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tiisg%8;
const short i = ii + tx;
float4x4 mv;
dequantize_func(pv4 + i/nl, i%nl, mv);
lo[ii/NW4] += (half4x4)(mv*lss);
lo[ii/NW4] += mv*lss;
}
}
}
@ -3458,14 +3454,8 @@ kernel void kernel_flash_attn_ext_vec(
}
// store results to shared memory
for (short ii = 0; ii < D16; ii += NW4) {
short i = ii + tiisg;
if (tiisg < 8) {
sr4[4*i + 0] = lo[ii/NW4][0];
sr4[4*i + 1] = lo[ii/NW4][1];
sr4[4*i + 2] = lo[ii/NW4][2];
sr4[4*i + 3] = lo[ii/NW4][3];
}
for (short i = tiisg; i < D16; i += NW4) {
sr44[i] = lo[i/NW4];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3492,24 +3482,22 @@ kernel void kernel_flash_attn_ext_vec(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
for (short i = tiisg; i < D16; i += NW) {
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4 * dst4 = (device float4 *) dst;
device float4x4 * dst44 = (device float4x4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
for (short i = tiisg; i < D16; i += NW) {
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
}
}
}