diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5d5b98307..477f720a0 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1693,11 +1693,13 @@ static void ggml_metal_encode_node( const int64_t n_seq_tokens = ne11; const int64_t n_seqs = ne13; + id pipeline = nil; + if (ne30 == 1) { // Mamba-2 - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; } else { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; } [encoder setComputePipelineState:pipeline];