cuda : avoid zeroing fragments

This commit is contained in:
Georgi Gerganov 2024-02-01 22:08:37 +02:00
parent c6769b9422
commit db1f3c482e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16(
const int iq2 = blockIdx.y; const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q; const int iq1 = blockIdx.x * Q;
const int D2 = D/2; const int D2 = D/2;
const int D16 = D/16; const int D16 = D/16;
const int Q16 = Q/16; const int Q16 = Q/16;
const int NW = WARP_SIZE; const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half) const int SH = (C + Q); // shared memory per simdgroup in (half)
const int T = D + num_warps*SH; // shared memory size per query in (half) const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2) const int T2 = T/2; // shared memory size per query in (half2)
@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::fill_fragment(lo[j][i], 0.0); nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
} }
// restore zeros // restore zeros
@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int64_t i = 0; i < D16; ++i) { for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2); nvcuda::wmma::mma_sync(t2, ms1, t, zr);
// convert accumulator to matrix_b // convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);