diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c98b551b3..67541a61e 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16(
 
     // reduce the warps sequentially
     for (int sg = 1; sg < num_warps; ++sg) {
-        half S = __float2half(0.0f);
-        half M = CUDART_MIN_DENORM_FP16;
-
         __syncthreads();
 
         // each simdgroup stores its output to shared memory, reusing sq
@@ -6733,27 +6730,25 @@ static __global__ void flash_attn_ext_f16(
 
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
-            for (int j = 0; j < Q; ++j) {
+            for (int j = lane_id; j < Q; j += NW) {
                 const half S0 = ss[j*T +         0];
                 const half S1 = ss[j*T + sg*SH + 0];
 
                 const half M0 = ss[j*T +         1];
                 const half M1 = ss[j*T + sg*SH + 1];
 
-                M = __hmax(M0, M1);
+                const half M = __hmax(M0, M1);
 
                 const half ms0 = hexp(M0 - M);
                 const half ms1 = hexp(M1 - M);
 
-                S = S0*ms0 + S1*ms1;
+                const half S = S0*ms0 + S1*ms1;
 
-                if (lane_id == 0) {
-                    ss[j*T + 0] = S;
-                    ss[j*T + 1] = M;
+                ss[j*T + 0] = S;
+                ss[j*T + 1] = M;
 
-                    ss[j*T + C + j         ] = ms0;
-                    ss[j*T + C + j + sg*SH] = ms1;
-                }
+                ss[j*T + C + j         ] = ms0;
+                ss[j*T + C + j + sg*SH] = ms1;
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
@@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int nqpb = NQPB; // queries per block
     const int ncpw = NCPW; // cache values per warp (does not work for other values)
 
+    GGML_ASSERT(NQPB <= 32);
+
    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
                              // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
    const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
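
Note on the merged loop above: the S/M updates implement the standard "online softmax" merge of two partial accumulators, where M is the running row maximum and S the running sum of exp(x - M); when two partials with different maxima are combined, each sum is rescaled by exp(M_i - M) before adding, so both are expressed relative to the shared maximum. The sketch below restates that merge step in isolation, using the same half intrinsics (__hmax, hexp) as the patch; the names softmax_state_t and merge_softmax_state are hypothetical helpers for illustration, not part of the patch.

#include <cuda_fp16.h>

// One partial softmax accumulator, as stored per warp in ss[]:
// S = sum of exp(x - M) over the values seen so far, M = their maximum.
struct softmax_state_t {
    half S;
    half M;
};

// Merge two partial accumulators into one (hypothetical helper, mirroring
// the arithmetic of the patched loop in flash_attn_ext_f16).
__device__ softmax_state_t merge_softmax_state(softmax_state_t a, softmax_state_t b) {
    softmax_state_t r;

    r.M = __hmax(a.M, b.M);            // shared maximum of both partials

    const half ms_a = hexp(a.M - r.M); // rescale factor for a's sum
    const half ms_b = hexp(b.M - r.M); // rescale factor for b's sum

    r.S = a.S*ms_a + b.S*ms_b;         // both sums now relative to r.M

    return r;
}

The switch from `for (int j = 0; j < Q; ++j)` under an `if (lane_id == 0)` guard to the lane-strided `for (int j = lane_id; j < Q; j += NW)` is safe because each row j is merged independently: every lane owns a distinct set of rows, so no two lanes touch the same ss[] entries and the serialization on lane 0 can be dropped. The new GGML_ASSERT(NQPB <= 32) makes the queries-per-block bound explicit alongside this change.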