mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 07:19:53 +00:00
metal : remove unused arguments for SSM_SCAN
Metal allows at most 31 buffer arguments per kernel (valid indices 0-30), so the argument list has to be trimmed to fit.
This commit is contained in:
parent
03d0e6eabe
commit
7a351abc28
@ -1655,7 +1655,7 @@ static void ggml_metal_encode_node(
|
|||||||
const int64_t ne30 = src3->ne[0];
|
const int64_t ne30 = src3->ne[0];
|
||||||
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
||||||
|
|
||||||
const uint64_t nb30 = src3->nb[0];
|
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
|
||||||
const uint64_t nb31 = src3->nb[1];
|
const uint64_t nb31 = src3->nb[1];
|
||||||
|
|
||||||
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
||||||
@ -1663,7 +1663,7 @@ static void ggml_metal_encode_node(
|
|||||||
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
||||||
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
||||||
|
|
||||||
const uint64_t nb40 = src4->nb[0];
|
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
|
||||||
const uint64_t nb41 = src4->nb[1];
|
const uint64_t nb41 = src4->nb[1];
|
||||||
const uint64_t nb42 = src4->nb[2];
|
const uint64_t nb42 = src4->nb[2];
|
||||||
const uint64_t nb43 = src4->nb[3];
|
const uint64_t nb43 = src4->nb[3];
|
||||||
@ -1673,18 +1673,18 @@ static void ggml_metal_encode_node(
|
|||||||
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
||||||
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
||||||
|
|
||||||
const uint64_t nb50 = src5->nb[0];
|
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
|
||||||
const uint64_t nb51 = src5->nb[1];
|
const uint64_t nb51 = src5->nb[1];
|
||||||
const uint64_t nb52 = src5->nb[2];
|
const uint64_t nb52 = src5->nb[2];
|
||||||
const uint64_t nb53 = src5->nb[3];
|
const uint64_t nb53 = src5->nb[3];
|
||||||
|
|
||||||
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
||||||
|
|
||||||
const uint64_t nb60 = src6->nb[0];
|
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
||||||
|
|
||||||
const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
|
const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
|
||||||
|
|
||||||
const uint64_t nb70 = src7->nb[0];
|
const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70);
|
||||||
|
|
||||||
const int64_t d_state = ne00;
|
const int64_t d_state = ne00;
|
||||||
const int64_t d_inner = ne01;
|
const int64_t d_inner = ne01;
|
||||||
@ -1718,32 +1718,23 @@ static void ggml_metal_encode_node(
|
|||||||
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
|
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
|
||||||
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
|
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
|
||||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
|
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
|
||||||
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];
|
|
||||||
|
|
||||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14];
|
||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16];
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15];
|
||||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17];
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16];
|
||||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17];
|
||||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18];
|
||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20];
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19];
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21];
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20];
|
||||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22];
|
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21];
|
||||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23];
|
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
||||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24];
|
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23];
|
||||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25];
|
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24];
|
||||||
[encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26];
|
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25];
|
||||||
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27];
|
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26];
|
||||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28];
|
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27];
|
||||||
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29];
|
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28];
|
||||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30];
|
// NOTE: Metal provides 31 buffer argument slots per kernel, so the highest valid index is 30
|
||||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31];
|
|
||||||
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32];
|
|
||||||
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33];
|
|
||||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34];
|
|
||||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35];
|
|
||||||
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36];
|
|
||||||
[encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37];
|
|
||||||
[encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38];
|
|
||||||
|
|
||||||
if (ne30 == 1) {
|
if (ne30 == 1) {
|
||||||
// Mamba-2
|
// Mamba-2
|
||||||
|
@ -812,30 +812,21 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
constant int64_t & n_head,
|
constant int64_t & n_head,
|
||||||
constant int64_t & n_group,
|
constant int64_t & n_group,
|
||||||
constant int64_t & n_seq_tokens,
|
constant int64_t & n_seq_tokens,
|
||||||
constant int64_t & n_seqs,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant uint64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant uint64_t & nb10,
|
|
||||||
constant uint64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant uint64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant uint64_t & nb20,
|
|
||||||
constant uint64_t & nb21,
|
constant uint64_t & nb21,
|
||||||
constant uint64_t & nb22,
|
constant uint64_t & nb22,
|
||||||
constant uint64_t & nb30,
|
|
||||||
constant uint64_t & nb31,
|
constant uint64_t & nb31,
|
||||||
constant uint64_t & nb40,
|
|
||||||
constant uint64_t & nb41,
|
constant uint64_t & nb41,
|
||||||
constant uint64_t & nb42,
|
constant uint64_t & nb42,
|
||||||
constant uint64_t & nb43,
|
constant uint64_t & nb43,
|
||||||
constant uint64_t & nb50,
|
|
||||||
constant uint64_t & nb51,
|
constant uint64_t & nb51,
|
||||||
constant uint64_t & nb52,
|
constant uint64_t & nb52,
|
||||||
constant uint64_t & nb53,
|
constant uint64_t & nb53,
|
||||||
constant uint64_t & nb60,
|
|
||||||
constant uint64_t & nb70,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -843,12 +834,16 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
const int64_t ir = tgpig.x; // current head
|
const int64_t ir = tgpig.x; // current head
|
||||||
const int64_t i3 = tgpig.y; // current seq
|
const int64_t i3 = tgpig.y; // current seq
|
||||||
|
|
||||||
|
const uint64_t nb00 = sizeof(float);
|
||||||
|
const uint64_t nb10 = sizeof(float);
|
||||||
|
const uint64_t nb20 = sizeof(float);
|
||||||
|
const uint64_t nb60 = sizeof(float);
|
||||||
|
|
||||||
const int64_t nc = d_state;
|
const int64_t nc = d_state;
|
||||||
const int64_t nr = d_inner;
|
const int64_t nr = d_inner;
|
||||||
const int64_t nh = n_head;
|
const int64_t nh = n_head;
|
||||||
const int64_t ng = n_group;
|
const int64_t ng = n_group;
|
||||||
const int64_t n_t = n_seq_tokens;
|
const int64_t n_t = n_seq_tokens;
|
||||||
const int64_t n_s = n_seqs;
|
|
||||||
|
|
||||||
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
||||||
|
|
||||||
@ -864,7 +859,7 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
||||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
||||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
@ -901,30 +896,21 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
constant int64_t & n_head,
|
constant int64_t & n_head,
|
||||||
constant int64_t & n_group,
|
constant int64_t & n_group,
|
||||||
constant int64_t & n_seq_tokens,
|
constant int64_t & n_seq_tokens,
|
||||||
constant int64_t & n_seqs,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant uint64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant uint64_t & nb10,
|
|
||||||
constant uint64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant uint64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant uint64_t & nb20,
|
|
||||||
constant uint64_t & nb21,
|
constant uint64_t & nb21,
|
||||||
constant uint64_t & nb22,
|
constant uint64_t & nb22,
|
||||||
constant uint64_t & nb30,
|
|
||||||
constant uint64_t & nb31,
|
constant uint64_t & nb31,
|
||||||
constant uint64_t & nb40,
|
|
||||||
constant uint64_t & nb41,
|
constant uint64_t & nb41,
|
||||||
constant uint64_t & nb42,
|
constant uint64_t & nb42,
|
||||||
constant uint64_t & nb43,
|
constant uint64_t & nb43,
|
||||||
constant uint64_t & nb50,
|
|
||||||
constant uint64_t & nb51,
|
constant uint64_t & nb51,
|
||||||
constant uint64_t & nb52,
|
constant uint64_t & nb52,
|
||||||
constant uint64_t & nb53,
|
constant uint64_t & nb53,
|
||||||
constant uint64_t & nb60,
|
|
||||||
constant uint64_t & nb70,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -932,12 +918,16 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
const int64_t ir = tgpig.y; // current head
|
const int64_t ir = tgpig.y; // current head
|
||||||
const int64_t i3 = tgpig.z; // current seq
|
const int64_t i3 = tgpig.z; // current seq
|
||||||
|
|
||||||
|
const uint64_t nb00 = sizeof(float);
|
||||||
|
const uint64_t nb10 = sizeof(float);
|
||||||
|
const uint64_t nb20 = sizeof(float);
|
||||||
|
const uint64_t nb60 = sizeof(float);
|
||||||
|
|
||||||
const int64_t nc = d_state;
|
const int64_t nc = d_state;
|
||||||
const int64_t nr = d_inner;
|
const int64_t nr = d_inner;
|
||||||
const int64_t nh = n_head;
|
const int64_t nh = n_head;
|
||||||
const int64_t ng = n_group;
|
const int64_t ng = n_group;
|
||||||
const int64_t n_t = n_seq_tokens;
|
const int64_t n_t = n_seq_tokens;
|
||||||
const int64_t n_s = n_seqs;
|
|
||||||
|
|
||||||
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
||||||
|
|
||||||
@ -953,7 +943,7 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
||||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
||||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
|
Loading…
Reference in New Issue
Block a user