From ab558ac2b3b892cfff5b9f1467bdc74a9d1d8d71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Dec 2023 10:54:17 +0200 Subject: [PATCH] metal : fix soft_max kernels ref: https://github.com/ggerganov/ggml/pull/621/commits/1914017863d2f9ab8ecc0281cc2a56d683668b92 --- ggml-metal.m | 4 +++- ggml-metal.metal | 16 ++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 090c84f59..1aa3424d2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1269,6 +1269,8 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; @@ -1520,7 +1522,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + nrows - 1)/nrows; + const int64_t ny = (ne11 + nrows - 1)/nrows; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } diff --git a/ggml-metal.metal b/ggml-metal.metal index c246e8645..8b76f969c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -347,9 +347,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max float lmax = -INFINITY; @@ -386,6 +386,8 @@ kernel void kernel_soft_max( } float sum = simd_sum(lsum); + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; @@ -428,9 +430,9 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float4 lmax4 = -INFINITY; @@ -468,6 +470,8 @@ kernel void kernel_soft_max_4( } const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + threadgroup_barrier(mem_flags::mem_threadgroup); + float sum = simd_sum(lsum); if (ntg > N_SIMDWIDTH) { if (sgitg == 0) {