mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : avoid zeroing fragments
This commit is contained in:
parent
c6769b9422
commit
db1f3c482e
12
ggml-cuda.cu
12
ggml-cuda.cu
@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int iq2 = blockIdx.y;
|
||||
const int iq1 = blockIdx.x * Q;
|
||||
|
||||
const int D2 = D/2;
|
||||
const int D2 = D/2;
|
||||
const int D16 = D/16;
|
||||
const int Q16 = Q/16;
|
||||
const int NW = WARP_SIZE;
|
||||
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
const int NW = WARP_SIZE;
|
||||
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const int T = D + num_warps*SH; // shared memory size per query in (half)
|
||||
const int T2 = T/2; // shared memory size per query in (half2)
|
||||
@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
||||
|
||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
||||
}
|
||||
|
||||
// restore zeros
|
||||
@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(t2, 0.0);
|
||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
|
||||
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
||||
|
||||
// convert accumulator to matrix_b
|
||||
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
|
Loading…
Reference in New Issue
Block a user