metal : parallel reduce across heads

This commit is contained in:
Georgi Gerganov 2024-01-21 22:44:41 +02:00
parent 77d08f3272
commit 17720fad66
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 22 additions and 14 deletions

View File

@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27]; [encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nwarps = 16; const int64_t nwarps = 32;
const int64_t nhptg = 4; // heads per threadgroup const int64_t nhptg = 2; // heads per threadgroup
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);

View File

@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16(
half4 s4 = 0.0h; half4 s4 = 0.0h;
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) { for (int64_t i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
} }
@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16(
if (tiih == 0) { if (tiih == 0) {
half s = 0.0h; half s = 0.0h;
#pragma unroll
for (int64_t i = 0; i < tph; ++i) { for (int64_t i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i]; s += ss[hiisg*tph + i];
} }
s = s*scale + mv; s = s*scale + mv;
const half Mold = M; const half m = M;
M = max(M, s); M = max(M, s);
const half ms = exp(Mold - M); const half ms = exp(m - M);
const half vs = exp(s - M); const half vs = exp(s - M);
S = S*ms + vs; S = S*ms + vs;
@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16(
const half ms = ss[2*hiisg + 0]; const half ms = ss[2*hiisg + 0];
const half vs = ss[2*hiisg + 1]; const half vs = ss[2*hiisg + 1];
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) { for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
} }
@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// reduce the warps // reduce the warps
if (sgitg == 0 && tiih == 0) { if (sgitg == 0) {
for (int64_t sg = 1; sg < nsg; ++sg) { for (int64_t sg = 1; sg < nsg; ++sg) {
const half S0 = S; const half S0 = ss[ 2*hiisg + 0];
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
const half M0 = M; const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
M = max(M0, M1); M = max(M0, M1);
@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16(
S = S0*ms0 + S1*ms1; S = S0*ms0 + S1*ms1;
for (int64_t i = 0; i < D4; ++i) { if (tiih == 0) {
ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
} }
} }
for (int64_t i = 0; i < D4; ++i) { for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + i] /= S; ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
} }
} }
@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,