mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
wip
This commit is contained in:
parent
035c4f01e6
commit
5cbdba693d
@ -2253,16 +2253,16 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
const int64_t nwarps = 4;
|
const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
|
||||||
const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !!
|
const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !!
|
||||||
const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
const size_t smem = nqptg*(nhptg*ne00 + nwarps*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||||
|
|
||||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
@ -2091,8 +2091,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
// load H heads from Q to shared memory
|
// load H heads from Q to shared memory
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||||
if (sgitg < Q) {
|
for (int64_t j = sgitg; j < Q; j += nsg) {
|
||||||
const int64_t j = sgitg;
|
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < ne01) {
|
||||||
pq4[j*T4 + hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
pq4[j*T4 + hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
||||||
} else {
|
} else {
|
||||||
@ -2180,28 +2179,28 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||||
|
half4 pv4v[8];
|
||||||
|
|
||||||
for (int p = 0; p < 8; ++p) {
|
for (int p = 0; p < 8; ++p) {
|
||||||
const int64_t ic = iic + p;
|
const int64_t ic = iic + p;
|
||||||
|
|
||||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half ms[Q] = { 1.0h };
|
pv4v[p] = pv4[tph*i + tiih];
|
||||||
half vs[Q] = { 0.0h };
|
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
|
||||||
ms[j] = ss[j*T + 32*p + 2*hiisg + 0];
|
|
||||||
vs[j] = ss[j*T + 32*p + 2*hiisg + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
thread half4 pv4v[D4/tph];
|
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
|
||||||
pv4v[i] = pv4[tph*i + tiih];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
half4 ps4v = ps4[j*T4 + hiisg*D4 + tph*i + tiih];
|
||||||
ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms[j] + pv4v[i]*vs[j];
|
|
||||||
|
for (int p = 0; p < 8; ++p) {
|
||||||
|
const half ms = ss[j*T + 32*p + 2*hiisg + 0];
|
||||||
|
const half vs = ss[j*T + 32*p + 2*hiisg + 1];
|
||||||
|
|
||||||
|
ps4v = ps4v*ms + pv4v[p]*vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user