diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 098b55e07..7130209e7 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16(
     const int iq2 = blockIdx.y;
     const int iq1 = blockIdx.x * Q;
 
-    const int D2 = D/2;
+    const int D2  = D/2;
     const int D16 = D/16;
     const int Q16 = Q/16;
-    const int NW = WARP_SIZE;
-    const int SH = (C + Q); // shared memory per simdgroup in (half)
+    const int NW  = WARP_SIZE;
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
@@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16(
                         nvcuda::wmma::store_matrix_sync(   ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
                         nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
 
-                        nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
-                        nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
+                        nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
                     }
 
                     // restore zeros
@@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16(
             nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
 
             for (int64_t i = 0; i < D16; ++i) {
-                nvcuda::wmma::fill_fragment(t2, 0.0);
                 nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
-                nvcuda::wmma::mma_sync(t2, ms1, t, t2);
+                nvcuda::wmma::mma_sync(t2, ms1, t, zr);
 
                 // convert accumulator to matrix_b
                 nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
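
Both code hunks apply the same micro-optimization: instead of zeroing the destination fragment with `fill_fragment` and then accumulating into it (`d = a*b + d` with `d == 0`), the kernel passes a pre-zeroed fragment `zr` as the addend of `mma_sync`, so the destination is overwritten directly (`d = a*b + 0`) and the per-iteration `fill_fragment` call disappears from the inner loops. Below is a minimal standalone sketch of that pattern, not taken from ggml-cuda.cu; the fragment typedefs, kernel name, and tile sizes are illustrative assumptions.

```cuda
#include <mma.h>
#include <cuda_fp16.h>

// Fragment typedefs mirroring the 16x16x16 half tiles used in the diff;
// the names here are assumptions for this sketch.
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;

// launch with a single warp (e.g. <<<1, 32>>>), niter >= 1, requires sm_70+
static __global__ void mma_zero_addend(const half * A, const half * B, half * C, int niter) {
    half16x16_a   a;
    half16x16_b   b;
    half16x16_acc acc;
    half16x16_acc zr;

    // the zero addend is filled exactly once, outside the hot loop
    nvcuda::wmma::fill_fragment(zr, 0.0);

    nvcuda::wmma::load_matrix_sync(a, A, 16);

    for (int i = 0; i < niter; ++i) {
        nvcuda::wmma::load_matrix_sync(b, B, 16);

        // acc = a*b + zr : the accumulator is fully overwritten each iteration,
        // so no per-iteration fill_fragment(acc, 0.0) is needed
        nvcuda::wmma::mma_sync(acc, a, b, zr);
    }

    nvcuda::wmma::store_matrix_sync(C, acc, 16, nvcuda::wmma::mem_row_major);
}
```

Since `mma_sync` never writes to its addend operand, `zr` stays zero for the lifetime of the kernel and can be shared by every inner loop that needs a fresh (non-accumulating) matrix product, which is exactly how the diff reuses it in both hunks.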