diff --git a/ggml-cuda.cu b/ggml-cuda.cu index deda4cc70..4d1fb008c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); - M[i] = __float2half(-INFINITY); + M[i] = CUDART_MIN_DENORM_FP16; } // assume K and V are same shape @@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16( half smax = __float2half(-INFINITY); // online softmax - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; + for (int j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - const half m = M[j]; const half s = ss[j*T + p]; - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + M[j] = warp_reduce_max(M[j]); - S[j] = S[j]*ms + warp_reduce_sum(vs); + const half ms = hexp(m - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + + const half s = ss[j*T + p]; + + const half vs = hexp(s - M[j]); + + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - - const half s = ss[j*T + p]; - - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; - - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - - const half s = ss[j*T + p]; - - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } - - S[j] = S[j]*ms + warp_reduce_sum(ls); - } + S[j] = S[j]*ms + warp_reduce_sum(ls); } smax = warp_reduce_max(smax); @@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); - half M = __float2half(-INFINITY); + half M = CUDART_MIN_DENORM_FP16; __syncthreads(); @@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); S = S0*ms0 + S1*ms1;