diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5127b34f8..3f7183060 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3745f2f22..c36eedb01 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, @@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03,