From 2c77d799f9387f5971289139aaca23b4ce37c435 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:36:22 -0400 Subject: [PATCH] metal : attempt to adapt SSM_SCAN for Mamba-2 --- ggml/src/ggml-metal.m | 101 ++++++++++++++++++-------- ggml/src/ggml-metal.metal | 148 ++++++++++++++++++++++++++++++++------ 2 files changed, 200 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9da08fe2e..5d5b98307 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -95,6 +95,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -591,6 +592,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; + struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); + GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; + size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; + id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const uint64_t nb30 = src3->nb[0]; const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); const uint64_t nb40 = src4->nb[0]; const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); const uint64_t nb50 = src5->nb[0]; const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; + + const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); + + const uint64_t nb70 = src7->nb[0]; const int64_t d_state = ne00; const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; + [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; + [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; + [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b2000323..c75fa25c3 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part // TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, @@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, + device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, constant int64_t & n_seq_tokens, constant int64_t & n_seqs, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32( constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, + constant uint64_t & nb43, constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq const int64_t nc = d_state; const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; const int64_t n_s = n_seqs; + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } - y[0] = sumf; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device const void * src7, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb43, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = expf(dt_soft_plus * A[0]); + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; } }