cuda : fix matrix names

This commit is contained in:
Georgi Gerganov 2024-02-03 18:36:58 +02:00
parent cfd9732b2e
commit ef68fac2a8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16(
for (int cc = 0; cc < C/16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mk[D16];
half16x16_b mv[D16];
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a mv[Q16];
half16x16_a ms[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
}
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
}
}
}