mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 07:19:53 +00:00
metal : add back n_seqs to SSM_SCAN args
Whoops, this is needed for the offset in the concatenated output.
This commit is contained in:
parent
7a351abc28
commit
8b15bc6fa0
@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node(
|
|||||||
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
|
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
|
||||||
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
|
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
|
||||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
|
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
|
||||||
|
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];
|
||||||
|
|
||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15];
|
||||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15];
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16];
|
||||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16];
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17];
|
||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
||||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19];
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20];
|
||||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20];
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21];
|
||||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21];
|
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22];
|
||||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23];
|
||||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23];
|
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
||||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24];
|
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
||||||
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25];
|
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26];
|
||||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26];
|
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
||||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27];
|
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
||||||
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28];
|
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29];
|
||||||
// NOTE: max index is 31
|
// NOTE: max index is 31
|
||||||
|
|
||||||
if (ne30 == 1) {
|
if (ne30 == 1) {
|
||||||
|
@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
constant int64_t & n_head,
|
constant int64_t & n_head,
|
||||||
constant int64_t & n_group,
|
constant int64_t & n_group,
|
||||||
constant int64_t & n_seq_tokens,
|
constant int64_t & n_seq_tokens,
|
||||||
|
constant int64_t & n_seqs,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant uint64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
constant int64_t & n_head,
|
constant int64_t & n_head,
|
||||||
constant int64_t & n_group,
|
constant int64_t & n_group,
|
||||||
constant int64_t & n_seq_tokens,
|
constant int64_t & n_seq_tokens,
|
||||||
|
constant int64_t & n_seqs,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant uint64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
|
Loading…
Reference in New Issue
Block a user