Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-30 13:24:35 +00:00)
metal : improve perf via smaller int registers
commit 57c03b78b6 (parent 6be02b5969)

ggml-metal.metal | 131 lines changed
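The diff below makes one systematic change inside kernel_flash_attn_ext_f16: 64-bit counters and indices (int64_t, uint64_t) become 16-bit short/ushort wherever the value is bounded by the kernel's tile parameters, and 32-bit int/uint where it scales with the KV sequence length. Narrower index registers lower per-thread register pressure, which on Apple GPUs can translate into higher occupancy. A minimal sketch of the pattern, using a hypothetical kernel that is not part of this commit:

#include <metal_stdlib>
using namespace metal;

// Hypothetical example (not from the commit): lane indices are small, bounded
// values, so 16-bit ushort/short registers are enough. Only the row index,
// which scales with the grid, keeps a 32-bit register.
kernel void scale_rows_f16(
        device const half * src  [[buffer(0)]],
        device       half * dst  [[buffer(1)]],
        constant     int  & ne00 [[buffer(2)]], // row length, assumed < 32768
        uint3  tgpig [[threadgroup_position_in_grid]],
        ushort tiisg [[thread_index_in_simdgroup]]) { // was typically: uint
    const short NW  = 32;      // simdgroup width: trivially fits in 16 bits
    const int   row = tgpig.x; // may exceed 16 bits -> stays 32-bit

    // 16-bit loop counter, same pattern as the short loops in the diff
    for (short i = tiisg; i < short(ne00); i += NW) {
        dst[row*ne00 + i] = src[row*ne00 + i] * 0.5h;
    }
}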
@@ -2064,8 +2064,8 @@ typedef void (flash_attn_ext_f16_t)(
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
         uint3    ntg[[threads_per_threadgroup]],
-        uint   tiisg[[thread_index_in_simdgroup]],
-        uint   sgitg[[simdgroup_index_in_threadgroup]]);
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
 template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
@@ -2102,22 +2102,22 @@ kernel void kernel_flash_attn_ext_f16(
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
         uint3    ntg[[threads_per_threadgroup]],
-        uint   tiisg[[thread_index_in_simdgroup]],
-        uint   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const uint nsg = ntg.y; // number of simdgroups
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
 
-    const int64_t iq3 = tgpig[2];
-    const int64_t iq2 = tgpig[1];
-    const int64_t iq1 = tgpig[0]*Q;
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
 
-    const int64_t D4 = D/4;
-    const int64_t D8 = D/8;
-    const int64_t Q8 = Q/8;
-    const int64_t NW = N_SIMDWIDTH;
-    const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
 
-    const int64_t T  = D + nsg*SH; // shared memory size per query in (half)
-    const int64_t T4 = T/4;        // shared memory size per query in (half4)
+    const short T  = D + nsg*SH; // shared memory size per query in (half)
+    const short T4 = T/4;        // shared memory size per query in (half4)
 
     threadgroup half  * sq  = (threadgroup half  *) (shared + 0*D); // holds the query data
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
@@ -2127,10 +2127,10 @@ kernel void kernel_flash_attn_ext_f16(
     simdgroup_half8x8 lo[Q8][D8];
 
     // load heads from Q to shared memory
-    for (int64_t j = sgitg; j < Q; j += nsg) {
+    for (short j = sgitg; j < Q; j += nsg) {
         device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
 
-        for (int64_t i = tiisg; i < D4; i += NW) {
+        for (short i = tiisg; i < D4; i += NW) {
             if (iq1 + j < ne01) {
                 sq4[j*T4 + i] = (half4) q4[i];
             } else {
@@ -2140,15 +2140,15 @@ kernel void kernel_flash_attn_ext_f16(
     }
 
     // zero out lo
-    for (int64_t j = 0; j < Q8; ++j) {
-        for (int64_t i = 0; i < D8; ++i) {
+    for (short j = 0; j < Q8; ++j) {
+        for (short i = 0; i < D8; ++i) {
             lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
         }
     }
 
     // zero out shared memory SH
-    for (int64_t j = 0; j < Q; ++j) {
-        for (int64_t i = tiisg; i < SH; i += NW) {
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
             ss[j*T + i] = 0.0h;
         }
     }
@@ -2160,33 +2160,33 @@ kernel void kernel_flash_attn_ext_f16(
     half M[Q] = { [0 ... Q-1] = -INFINITY };
 
     // assume K and V are same shape
-    const int64_t ne22 = ne12;
-    const int64_t ne23 = ne13;
+    const short ne22 = ne12;
+    const short ne23 = ne13;
 
-    const uint64_t nb21 = nb11;
-    const uint64_t nb22 = nb12;
-    const uint64_t nb23 = nb13;
+    const uint nb21 = nb11;
+    const uint nb22 = nb12;
+    const uint nb23 = nb13;
 
     // broadcast
-    const int64_t rk2 = ne02/ne12;
-    const int64_t rk3 = ne03/ne13;
+    const short rk2 = ne02/ne12;
+    const short rk3 = ne03/ne13;
 
-    const int64_t rv2 = ne02/ne22;
-    const int64_t rv3 = ne03/ne23;
+    const short rv2 = ne02/ne22;
+    const short rv3 = ne03/ne23;
 
     // k indices
-    const int64_t ik2 = iq2 / rk2;
-    const int64_t ik3 = iq3 / rk3;
+    const short ik2 = iq2 / rk2;
+    const short ik3 = iq3 / rk3;
 
     // v indices
-    const int64_t iv2 = iq2 / rv2;
-    const int64_t iv3 = iq3 / rv3;
+    const short iv2 = iq2 / rv2;
+    const short iv3 = iq3 / rv3;
 
     // load the queries from shared memory into local memory
     simdgroup_half8x8 mq[Q8][D8];
 
-    for (int64_t j = 0; j < Q8; ++j) {
-        for (int64_t i = 0; i < D8; ++i) {
+    for (short j = 0; j < Q8; ++j) {
+        for (short i = 0; i < D8; ++i) {
             simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
         }
     }
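The broadcast ratios in the hunk above implement head sharing (as in grouped-query attention): rk2 = ne02/ne12 query heads map onto each K head, so query head iq2 reads K head iq2/rk2. A worked example with hypothetical sizes, 32 query heads over 8 K heads:

    rk2 = 32/8 = 4  ->  query heads 12..15 all read K head 12/4 = 3

Since these ratios and indices are head counts, they comfortably fit in a short.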
@@ -2199,28 +2199,33 @@ kernel void kernel_flash_attn_ext_f16(
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
-    for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
+    for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+        const int ic = ic0 + C*sgitg;
+
+        if (ic >= ne11) {
+            break;
+        }
 
         // Q*K^T
         {
-            for (int cc = 0; cc < C/8; ++cc) {
+            for (short cc = 0; cc < C/8; ++cc) {
                 simdgroup_half8x8 mqk[Q8];
-                for (int64_t j = 0; j < Q8; ++j) {
+                for (short j = 0; j < Q8; ++j) {
                     mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
                 }
 
                 device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-                for (int64_t i = 0; i < D8; ++i) {
+                for (short i = 0; i < D8; ++i) {
                     simdgroup_half8x8 mk;
                     simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
 
-                    for (int64_t j = 0; j < Q8; ++j) {
+                    for (short j = 0; j < Q8; ++j) {
                         simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
                     }
                 }
 
                 // mqk = mqk*scale + mask
-                for (int64_t j = 0; j < Q8; ++j) {
+                for (short j = 0; j < Q8; ++j) {
                     simdgroup_half8x8 mm;
                     simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
                     simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
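The hunk above is more than a type change: the KV-cache loop is restructured so its bounds no longer depend on sgitg. Every simdgroup now steps through the same ic0 sequence, applies its own offset in the body, and leaves via break once its block falls past ne11. Isolated, with identifiers as in the diff, the control flow becomes:

// before: each simdgroup had its own start, so the loop bounds diverged
// for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { ... }

// after: bounds are uniform across simdgroups; the offset moves inside
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
    const int ic = ic0 + C*sgitg;

    if (ic >= ne11) {
        break;
    }

    // ... process the C KV columns starting at column ic ...
}

The commit message only cites smaller integer registers, so whether the uniform trip count additionally helps the compiler is an educated guess.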
@@ -2237,8 +2242,8 @@ kernel void kernel_flash_attn_ext_f16(
             if (C == 32) {
                 half ms[Q];
 
-                for (int64_t j = 0; j < Q; ++j) {
-                    const int64_t p = tiisg;
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
 
                     const half m = M[j];
                     const half s = ss[j*T + p];
@@ -2262,10 +2267,10 @@ kernel void kernel_flash_attn_ext_f16(
             } else {
                 half ms[Q];
 
-                for (int64_t j = 0; j < Q; ++j) {
+                for (short j = 0; j < Q; ++j) {
                     const half m = M[j];
 
-                    for (int64_t p = tiisg; p < C; p += NW) {
+                    for (short p = tiisg; p < C; p += NW) {
                         const half s = ss[j*T + p];
 
                         smax = max(smax, s);
@@ -2280,7 +2285,7 @@ kernel void kernel_flash_attn_ext_f16(
                     // local sum
                     half ls = 0.0h;
 
-                    for (int64_t p = tiisg; p < C; p += NW) {
+                    for (short p = tiisg; p < C; p += NW) {
                         const half s = ss[j*T + p];
 
                         const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
@@ -2306,25 +2311,25 @@ kernel void kernel_flash_attn_ext_f16(
         }
 
         // O = diag(ms)*O
-        for (int64_t j = 0; j < Q8; ++j) {
+        for (short j = 0; j < Q8; ++j) {
             simdgroup_half8x8 mm;
             simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
 
-            for (int64_t i = 0; i < D8; ++i) {
+            for (short i = 0; i < D8; ++i) {
                 simdgroup_multiply(lo[j][i], mm, lo[j][i]);
             }
         }
 
         // O = O + (Q*K^T)*V
         {
-            for (int cc = 0; cc < C/8; ++cc) {
+            for (short cc = 0; cc < C/8; ++cc) {
                 device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
-                for (int64_t i = 0; i < D8; ++i) {
+                for (short i = 0; i < D8; ++i) {
                     simdgroup_half8x8 mk;
                     simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
 
-                    for (int64_t j = 0; j < Q8; ++j) {
+                    for (short j = 0; j < Q8; ++j) {
                         simdgroup_half8x8 mv;
                         simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
 
@@ -2336,7 +2341,7 @@ kernel void kernel_flash_attn_ext_f16(
     }
 
     // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-    for (int64_t j = 0; j < Q; ++j) {
+    for (short j = 0; j < Q; ++j) {
         if (tiisg == 0) {
             ss[j*T + 0] = S[j];
             ss[j*T + 1] = M[j];
@@ -2345,7 +2350,7 @@ kernel void kernel_flash_attn_ext_f16(
     }
 
     // reduce the warps sequentially
-    for (int64_t sg = 1; sg < nsg; ++sg) {
+    for (short sg = 1; sg < nsg; ++sg) {
         half S = { 0.0h };
         half M = { -INFINITY };
 
@@ -2353,8 +2358,8 @@ kernel void kernel_flash_attn_ext_f16(
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
-            for (int64_t j = 0; j < Q8; ++j) {
-                for (int64_t i = 0; i < D8; ++i) {
+            for (short j = 0; j < Q8; ++j) {
+                for (short i = 0; i < D8; ++i) {
                     simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
                 }
             }
@@ -2364,7 +2369,7 @@ kernel void kernel_flash_attn_ext_f16(
 
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (short j = 0; j < Q; ++j) {
                 const half S0 = ss[j*T +         0];
                 const half S1 = ss[j*T + sg*SH + 0];
 
@@ -2388,7 +2393,7 @@ kernel void kernel_flash_attn_ext_f16(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (int64_t j = 0; j < Q8; ++j) {
+            for (short j = 0; j < Q8; ++j) {
                 simdgroup_half8x8 t;
                 simdgroup_half8x8 ms0;
                 simdgroup_half8x8 ms1;
@@ -2396,7 +2401,7 @@ kernel void kernel_flash_attn_ext_f16(
                 simdgroup_load(ms0, ss + 8*j*T + C + 8*j,         T, 0, false);
                 simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
 
-                for (int64_t i = 0; i < D8; ++i) {
+                for (short i = 0; i < D8; ++i) {
                     simdgroup_load    (t, sq + 8*j*T + i*8, T, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
@@ -2408,8 +2413,8 @@ kernel void kernel_flash_attn_ext_f16(
 
     // store result to shared memory (reuse sq)
    if (sgitg == 0) {
-        for (int64_t j = 0; j < Q8; ++j) {
-            for (int64_t i = 0; i < D8; ++i) {
+        for (short j = 0; j < Q8; ++j) {
+            for (short i = 0; i < D8; ++i) {
                 simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
             }
         }
@@ -2419,10 +2424,10 @@ kernel void kernel_flash_attn_ext_f16(
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];
 
-            for (int64_t i = tiisg; i < D4; i += NW) {
+            for (short i = tiisg; i < D4; i += NW) {
                 dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
             }
         }
    }
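A note on why the narrowing is safe: Q, C, and the head size D are template constants, NW is the simd width of 32, and the shared-memory strides derive from them, so every short-typed value here is small. For example, with hypothetical upper bounds D = 256, nsg = 8, C = 32, Q = 8, the largest such value is T = D + nsg*(C + Q) = 256 + 8*40 = 576, far below the 32767 limit of a signed 16-bit short. Only the quantities that scale with the KV sequence length ne11 (the column index ic) or that are byte strides (nb21..nb23) keep 32 bits.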