From 4c190ba6767d7dee914c432af29fb0ef906d24c1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 21:17:08 +0200 Subject: [PATCH] cuda : reduce registers --- ggml-cuda/fattn.cu | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 9b723e809..cfb00af1f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -175,12 +175,12 @@ static __global__ void flash_attn_ext_f16( const int iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); - } - } + //half16x16_a mq[Q16][D16]; + //for (int j = 0; j < Q16; ++j) { + // for (int i = 0; i < D16; ++i) { + // nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + // } + //} // pointer to the mask const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; @@ -216,7 +216,9 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + half16x16_a mq; + nvcuda::wmma::load_matrix_sync(mq, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(mqk[j], mq, mk, mqk[j]); } } @@ -319,19 +321,13 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); - } - - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); - } - for (int j = 0; j < Q16; ++j) { + half16x16_a ms; + nvcuda::wmma::load_matrix_sync(ms, ss + 16*j*T + 16*cc, T); for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); + half16x16_b mv; + nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::mma_sync(lo[j][i], ms, mv, lo[j][i]); } } } @@ -554,6 +550,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } break; case 128: { + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + flash_attn_ext_f16<128, NQPB, NCPW> <<>> ( (const char *) Q->data, // Query @@ -572,9 +571,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } break; case 256: { - // increase shared memory limit to 64KB - //const size_t shmem_max = 64*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + const size_t shmem_max = 64*1024; + cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); flash_attn_ext_f16<256, NQPB, NCPW> <<>> (