metal : parallel reduce across heads

This commit is contained in:
Georgi Gerganov 2024-01-21 22:44:41 +02:00
parent 77d08f3272
commit 17720fad66
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 22 additions and 14 deletions

View File

@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nwarps = 16;
const int64_t nhptg = 4; // heads per threadgroup
const int64_t nwarps = 32;
const int64_t nhptg = 2; // heads per threadgroup
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);

View File

@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16(
half4 s4 = 0.0h;
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
}
@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16(
if (tiih == 0) {
half s = 0.0h;
#pragma unroll
for (int64_t i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i];
}
s = s*scale + mv;
const half Mold = M;
const half m = M;
M = max(M, s);
const half ms = exp(Mold - M);
const half ms = exp(m - M);
const half vs = exp(s - M);
S = S*ms + vs;
@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16(
const half ms = ss[2*hiisg + 0];
const half vs = ss[2*hiisg + 1];
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
}
@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup_barrier(mem_flags::mem_threadgroup);
// reduce the warps
if (sgitg == 0 && tiih == 0) {
if (sgitg == 0) {
for (int64_t sg = 1; sg < nsg; ++sg) {
const half S0 = S;
const half S0 = ss[ 2*hiisg + 0];
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
const half M0 = M;
const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
M = max(M0, M1);
@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16(
S = S0*ms0 + S1*ms1;
for (int64_t i = 0; i < D4; ++i) {
ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1;
if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
}
}
for (int64_t i = 0; i < D4; ++i) {
ps4[hiisg*D4 + i] /= S;
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
}
}
@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16(
}
}
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;
kernel void kernel_cpy_f16_f16(
device const half * src0,