diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e51ddc08f..25f810cbe 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
 
         {
-            half S[Q];
+            half S = __float2half(0.0f);
             half M[Q];
 
             for (int i = 0; i < Q; ++i) {
-                S[i] = __float2half(0.0f);
                 M[i] = CUDART_MIN_DENORM_FP16;
             }
 
@@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16(
 
                     M[j] = warp_reduce_max(M[j]);
 
-                    const half ms = hexp(m - M[j]);
-
-                    // create a QxQ diagonal matrix for rescaling the output
-                    if (lane_id == j) {
-                        ss[j*T + C + j] = ms;
-                    }
-
                     // local sum
                     half2 ls = make_half2(0.0f, 0.0f);
                     half2 M2 = make_half2(M[j], M[j]);
@@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16(
 
                     ls = warp_reduce_sum(ls);
 
-                    S[j] = S[j]*ms + ls.x + ls.y;
+                    const half ms = hexp(m - M[j]);
+
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (lane_id == j) {
+                        ss[j*T + C + j] = ms;
+
+                        S = S*ms + ls.x + ls.y;
+                    }
                 }
 
                 smax = warp_reduce_max(smax);
@@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16(
 
             // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
            for (int j = 0; j < Q; ++j) {
-                if (lane_id == 0) {
-                    ss[j*T + 0] = S[j];
+                if (lane_id == j) {
+                    ss[j*T + 0] = S;
                     ss[j*T + 1] = M[j];
                 }
             }
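The change above keeps the per-row maxima `M[Q]` replicated in every lane, since every lane still needs them to exponentiate its scores, but collapses the running softmax sums from a per-lane array `S[Q]` into a single register `S` that only the owner lane (`lane_id == j`) rescales, accumulates, and finally writes out. Below is a minimal float sketch of that ownership pattern, not the llama.cpp kernel itself; the kernel name `online_softmax_rows`, the parameters `scores`/`ncols`/`out`, and the constants `Q`, `C`, `NW` are invented for the illustration.

```cuda
// Minimal sketch, not the llama.cpp kernel: one warp runs an online softmax over
// Q rows of a score matrix that arrives in blocks of C columns. The row maxima
// M[Q] stay replicated in every lane; the running sum S is owned by lane j alone.
#include <cmath>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define NW 32   // warp size
#define Q   8   // rows handled by the warp (lane j owns row j)
#define C  32   // columns processed per block

static __device__ float warp_reduce_max(float x) {
    for (int offset = NW/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset));
    }
    return x;
}

static __device__ float warp_reduce_sum(float x) {
    for (int offset = NW/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

// out[j*2 + 0] = softmax denominator of row j, out[j*2 + 1] = maximum of row j
__global__ void online_softmax_rows(const float * scores, const int ncols, float * out) {
    const int lane_id = threadIdx.x % NW;

    float M[Q];     // running maxima, kept by every lane (all lanes exponentiate)
    float S = 0.0f; // running sum, a single register: lane j holds the sum of row j

    for (int j = 0; j < Q; ++j) {
        M[j] = -INFINITY;
    }

    // columns arrive in blocks of C, like the KV blocks in the attention kernel
    for (int ic = 0; ic < ncols; ic += C) {
        for (int j = 0; j < Q; ++j) {
            const float m = M[j];

            // new row maximum, identical in all lanes after the reduction
            float mb = -INFINITY;
            for (int p = lane_id; p < C && ic + p < ncols; p += NW) {
                mb = fmaxf(mb, scores[j*ncols + ic + p]);
            }
            M[j] = warp_reduce_max(fmaxf(M[j], mb));

            // local sum of exp(s - M[j]) over this block, accumulated by all lanes
            float ls = 0.0f;
            for (int p = lane_id; p < C && ic + p < ncols; p += NW) {
                ls += expf(scores[j*ncols + ic + p] - M[j]);
            }
            ls = warp_reduce_sum(ls);

            // only the owner lane rescales its running sum and folds in the block
            if (lane_id == j) {
                const float ms = expf(m - M[j]); // evaluates to 0 on the first block
                S = S*ms + ls;
            }
        }
    }

    // owner lanes publish their row's result (the counterpart of ss[j*T + 0] = S)
    if (lane_id < Q) {
        out[lane_id*2 + 0] = S;
        out[lane_id*2 + 1] = M[lane_id];
    }
}

int main() {
    const int ncols = 4*C;

    std::vector<float> h_scores(Q*ncols);
    for (int i = 0; i < Q*ncols; ++i) {
        h_scores[i] = 0.01f*(i % 97) - 0.5f; // arbitrary deterministic test data
    }

    float * d_scores = nullptr;
    float * d_out    = nullptr;
    cudaMalloc(&d_scores, Q*ncols*sizeof(float));
    cudaMalloc(&d_out,    Q*2*sizeof(float));
    cudaMemcpy(d_scores, h_scores.data(), Q*ncols*sizeof(float), cudaMemcpyHostToDevice);

    online_softmax_rows<<<1, NW>>>(d_scores, ncols, d_out);

    float h_out[Q*2];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    for (int j = 0; j < Q; ++j) {
        printf("row %d: sum = %8.4f  max = %8.4f\n", j, h_out[j*2 + 0], h_out[j*2 + 1]);
    }

    cudaFree(d_scores);
    cudaFree(d_out);
    return 0;
}
```

Because lane `j` is the only consumer of row `j`'s denominator, privatizing `S` to the owner lane shrinks the per-thread state from Q running sums to one, which appears to be what the final hunk's switch from `lane_id == 0` to `lane_id == j` relies on.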