mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
metal : parallelize across KV size
This commit is contained in:
parent
a4b6341c7b
commit
77d08f3272
@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
const int64_t nwarps = 8;
|
const int64_t nwarps = 16;
|
||||||
const int64_t nhpw = 4; // heads per warp
|
const int64_t nhptg = 4; // heads per threadgroup
|
||||||
|
|
||||||
const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2);
|
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
|
||||||
|
|
||||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
129
ggml-metal.metal
129
ggml-metal.metal
@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
|
|||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
template<int64_t D, int64_t R> // head size, rows per warp
|
template<int64_t D, int64_t R> // head size, rows per threadgroup
|
||||||
kernel void kernel_flash_attn_ext_f16(
|
kernel void kernel_flash_attn_ext_f16(
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
uint3 ntg[[threads_per_threadgroup]],
|
uint3 ntg[[threads_per_threadgroup]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
//const int64_t iq3 = tgpig[2];
|
const uint nsg = ntg.y; // number of simdgroups
|
||||||
//const int64_t iq2 = tgpig[1];
|
|
||||||
//const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg;
|
|
||||||
|
|
||||||
const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups
|
|
||||||
const uint tph = N_SIMDWIDTH/R; // threads per head
|
const uint tph = N_SIMDWIDTH/R; // threads per head
|
||||||
|
|
||||||
const int64_t iq3 = tgpig[2];
|
const int64_t iq3 = tgpig[2];
|
||||||
const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph;
|
const int64_t iq2 = tgpig[1]*R + tiisg/tph;
|
||||||
const int64_t iq1 = tgpig[0];
|
const int64_t iq1 = tgpig[0];
|
||||||
|
|
||||||
if (iq2 >= ne02) {
|
if (iq2 >= ne02) {
|
||||||
@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
|
device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
|
||||||
|
|
||||||
// const int64_t D4 = D/4;
|
|
||||||
//
|
|
||||||
// // TODO: can we move this to the stack?
|
|
||||||
// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared);
|
|
||||||
//
|
|
||||||
// // initialize with zeros
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
//
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D);
|
|
||||||
//
|
|
||||||
// // load Q to shared memory
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// half S = 0.0h;
|
|
||||||
// half M = -INFINITY;
|
|
||||||
//
|
|
||||||
// for (int64_t ic = 0; ic < ne11; ++ic) {
|
|
||||||
// const half mv = mp ? mp[ic] : 0.0h;
|
|
||||||
// if (mv == -INFINITY) {
|
|
||||||
// continue;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
|
||||||
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
|
||||||
//
|
|
||||||
// half4 s4 = 0.0h;
|
|
||||||
//
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
// s4 += pk4[d] * pq4[d];
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
|
|
||||||
//
|
|
||||||
// const half Mold = M;
|
|
||||||
//
|
|
||||||
// M = max(M, s);
|
|
||||||
//
|
|
||||||
// const half ms = exp(Mold - M);
|
|
||||||
// const half vs = exp(s - M);
|
|
||||||
//
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
// V16[d] = V16[d]*ms + pv4[d]*vs;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// S = S*ms + vs;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
// V16[d] /= S;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // dst indices
|
|
||||||
// const int64_t i1 = iq1;
|
|
||||||
// const int64_t i2 = iq2;
|
|
||||||
// const int64_t i3 = iq3;
|
|
||||||
//
|
|
||||||
// device float4 * dst4 = (device float4 *) dst;
|
|
||||||
//
|
|
||||||
// for (int64_t d = 0; d < D4; ++d) {
|
|
||||||
// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d];
|
|
||||||
// }
|
|
||||||
|
|
||||||
const int64_t D4 = D/4;
|
const int64_t D4 = D/4;
|
||||||
|
|
||||||
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D);
|
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D);
|
||||||
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D);
|
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D);
|
||||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D);
|
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D);
|
||||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D);
|
|
||||||
|
|
||||||
const uint tiih = tiisg%tph; // thread index in head
|
const uint tiih = tiisg%tph; // thread index in head
|
||||||
const uint hiisg = tiisg/tph; // head index in simdgroup
|
const uint hiisg = tiisg/tph; // head index in simdgroup
|
||||||
|
|
||||||
// load R heads from Q to shared memory
|
// load R heads from Q to shared memory
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||||
|
if (sgitg == 0) {
|
||||||
pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
||||||
|
}
|
||||||
|
|
||||||
ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
|
ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
half S = 0.0h;
|
half S = 0.0h;
|
||||||
half M = -INFINITY;
|
half M = -INFINITY;
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
for (int64_t ic = sgitg; ic < ne11; ic += nsg) {
|
||||||
const half mv = mp ? mp[ic] : 0.0h;
|
const half mv = mp ? mp[ic] : 0.0h;
|
||||||
if (mv == -INFINITY) {
|
if (mv == -INFINITY) {
|
||||||
continue;
|
continue;
|
||||||
@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
|
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
|
||||||
}
|
}
|
||||||
|
|
||||||
ss4[hiisg*tph + tiih] = s4;
|
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
s4 = 0.0h;
|
half s = 0.0h;
|
||||||
|
|
||||||
for (int64_t i = 0; i < tph; ++i) {
|
for (int64_t i = 0; i < tph; ++i) {
|
||||||
s4 += ss4[hiisg*tph + i];
|
s += ss[hiisg*tph + i];
|
||||||
}
|
}
|
||||||
|
|
||||||
half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
|
s = s*scale + mv;
|
||||||
|
|
||||||
const half Mold = M;
|
const half Mold = M;
|
||||||
|
|
||||||
@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
|
ss[2*hiisg + 0] = S;
|
||||||
|
ss[2*hiisg + 1] = M;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// reduce the warps
|
||||||
|
if (sgitg == 0 && tiih == 0) {
|
||||||
|
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||||
|
const half S0 = S;
|
||||||
|
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
|
||||||
|
|
||||||
|
const half M0 = M;
|
||||||
|
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
|
||||||
|
|
||||||
|
M = max(M0, M1);
|
||||||
|
|
||||||
|
const half ms0 = exp(M0 - M);
|
||||||
|
const half ms1 = exp(M1 - M);
|
||||||
|
|
||||||
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < D4; ++i) {
|
||||||
|
ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int64_t i = 0; i < D4; ++i) {
|
for (int64_t i = 0; i < D4; ++i) {
|
||||||
ps4[hiisg*D4 + i] /= S;
|
ps4[hiisg*D4 + i] /= S;
|
||||||
}
|
}
|
||||||
@ -2228,10 +2185,12 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
device float4 * dst4 = (device float4 *) dst;
|
device float4 * dst4 = (device float4 *) dst;
|
||||||
|
|
||||||
|
if (sgitg == 0) {
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||||
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih];
|
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>;
|
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>;
|
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>;
|
||||||
|
Loading…
Reference in New Issue
Block a user