cuda : avoid zeroing fragments

This commit is contained in:
Georgi Gerganov 2024-02-01 22:08:37 +02:00
parent c6769b9422
commit db1f3c482e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
}
// restore zeros
@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);