mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
metal : reorder write loop
This commit is contained in:
parent
5b359bb1e3
commit
535050572a
@ -6410,11 +6410,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int i = 0; i < n_rows; i++) {
|
|
||||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||||
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
|
||||||
|
device float4 * D4 = (device float4 *) D;
|
||||||
|
|
||||||
|
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||||
|
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for (; i < n_rows/4; i++) {
|
||||||
|
*(D4 + i) = *(C4 + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
i *= 4;
|
||||||
|
for (; i < n_rows; i++) {
|
||||||
|
*(D + i) = *(C + i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user