diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4d1fb008c..1fed9d23e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16( const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; @@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } M[j] = warp_reduce_max(M[j]); @@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16( } // local sum - half ls = 0.0f; + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - const half vs = hexp(s - M[j]); + const half2 vs = h2exp(s - M2); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss2[j*T2 + p] = vs; } - S[j] = S[j]*ms + warp_reduce_sum(ls); + ls = warp_reduce_sum(ls); + + S[j] = S[j]*ms + ls.x + ls.y; } smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; }