metal : reorder write loop

This commit is contained in:
Georgi Gerganov 2024-11-08 15:15:25 +02:00
parent e40b85abfb
commit 80b5b51bd8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -6403,11 +6403,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < n_rows/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < n_rows; i++) {
*(D + i) = *(C + i);
}
}
}