mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : fix matrix names
This commit is contained in:
parent
cfd9732b2e
commit
ef68fac2a8
10
ggml-cuda.cu
10
ggml-cuda.cu
@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
for (int cc = 0; cc < C/16; ++cc) {
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half16x16_b mk[D16];
|
half16x16_b mv[D16];
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
||||||
}
|
}
|
||||||
|
|
||||||
half16x16_a mv[Q16];
|
half16x16_a ms[Q16];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user