mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
Vulkan: Unroll more loops for more mul mat mat performance
This commit is contained in:
parent
a0deeeed28
commit
9622fbe373
@ -196,7 +196,7 @@ void main() {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
|
||||
for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
@ -209,7 +209,7 @@ void main() {
|
||||
}
|
||||
#endif
|
||||
|
||||
[[dont_unroll]] for (uint block = start_k; block < end_k; block += BK) {
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
@ -506,12 +506,12 @@ void main() {
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
#ifdef COOPMAT
|
||||
for (uint i = 0; i < BK; i += TK) {
|
||||
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
@ -519,7 +519,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (uint i = 0; i < BK; i++) {
|
||||
[[unroll]] for (uint i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||
|
Loading…
Reference in New Issue
Block a user