diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index d9f98bd0d..3a5149d2d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3139,7 +3139,6 @@ static void ggml_metal_encode_node( if (!use_vec_kernel) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 8; // keys per simdgroup const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! GGML_ASSERT(nqptg <= 32); @@ -3149,7 +3148,9 @@ static void ggml_metal_encode_node( int64_t nsgmax = 2; while (true) { - const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 4*16*nkpsg*nsgmax)*(sizeof(float)/2); + // 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2); if (smem > device.maxThreadgroupMemoryLength) { break; } @@ -3160,12 +3161,12 @@ static void ggml_metal_encode_node( // simdgroups per threadgroup (a.k.a. warps) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 4*16*nkpsg*nsg)*(sizeof(float)/2); + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2); //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 128) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 79c3c45d5..3ae1be095 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2723,45 +2723,9 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } -typedef void (flash_attn_ext_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); - // ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup +// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup +template kernel void kernel_flash_attn_ext( device const char * q, device const char * k, @@ -2818,8 +2782,8 @@ kernel void kernel_flash_attn_ext( threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*K) + Q*T); // scratch buffer to load K and V in shared memory - threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*K) + Q*T); // same as above but in half4x4 + threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory + threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; @@ -3179,6 +3143,8 @@ kernel void kernel_flash_attn_ext( } } +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3223,7 +3189,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_ // D - head size, Q - queries per threadgroup, C - cache items per threadgroup template -kernel void flash_attn_ext_vec( +kernel void kernel_flash_attn_ext_vec( device const char * q, device const char * k, device const char * v, @@ -3548,22 +3514,21 @@ kernel void flash_attn_ext_vec( } } -//template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>; -//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>; +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template kernel void kernel_cpy(