From ef68fac2a8b51e2237234e3d7c6120cade457ce8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:36:58 +0200 Subject: [PATCH] cuda : fix matrix names --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 25f810cbe..d9ab2bd09 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C/16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mk[D16]; + half16x16_b mv[D16]; for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); } - half16x16_a mv[Q16]; + half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } }