mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
metal : parallel reduce across heads
This commit is contained in:
parent
77d08f3272
commit
17720fad66
@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute(
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nwarps = 16;
|
||||
const int64_t nhptg = 4; // heads per threadgroup
|
||||
const int64_t nwarps = 32;
|
||||
const int64_t nhptg = 2; // heads per threadgroup
|
||||
|
||||
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
|
||||
|
||||
|
@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
half4 s4 = 0.0h;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
|
||||
}
|
||||
@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
if (tiih == 0) {
|
||||
half s = 0.0h;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < tph; ++i) {
|
||||
s += ss[hiisg*tph + i];
|
||||
}
|
||||
|
||||
s = s*scale + mv;
|
||||
|
||||
const half Mold = M;
|
||||
const half m = M;
|
||||
|
||||
M = max(M, s);
|
||||
|
||||
const half ms = exp(Mold - M);
|
||||
const half ms = exp(m - M);
|
||||
const half vs = exp(s - M);
|
||||
|
||||
S = S*ms + vs;
|
||||
@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
const half ms = ss[2*hiisg + 0];
|
||||
const half vs = ss[2*hiisg + 1];
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
|
||||
}
|
||||
@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// reduce the warps
|
||||
if (sgitg == 0 && tiih == 0) {
|
||||
if (sgitg == 0) {
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
const half S0 = S;
|
||||
const half S0 = ss[ 2*hiisg + 0];
|
||||
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
|
||||
|
||||
const half M0 = M;
|
||||
const half M0 = ss[ 2*hiisg + 1];
|
||||
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
for (int64_t i = 0; i < D4; ++i) {
|
||||
ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1;
|
||||
if (tiih == 0) {
|
||||
ss[2*hiisg + 0] = S;
|
||||
ss[2*hiisg + 1] = M;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < D4; ++i) {
|
||||
ps4[hiisg*D4 + i] /= S;
|
||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
|
Loading…
Reference in New Issue
Block a user