diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index f307b1ac6..f0a63d921 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -264,6 +264,12 @@ class Model:
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that)
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        del new_name, bid  # unused
+
+        return data_torch.squeeze()
+
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
@@ -295,7 +301,7 @@ class Model:
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                data = self.reshape_tensors(data_torch, new_name, bid).numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
@@ -3063,6 +3069,24 @@ class Mamba2Model(Model):
 
         yield (new_name, data_torch)
 
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            return data_torch.reshape((*data_torch.shape, 1))
+
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            return data_torch.reshape((n_group, d_inner // n_group))
+
+        return data_torch.squeeze()
+
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 0d2e5cb01..735f56b00 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1828,7 +1828,6 @@ extern "C" {
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
             struct ggml_tensor  * C,
-            struct ggml_tensor  * D,
             struct ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
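For reference, the reworked `ggml_ssm_scan` now takes `ids` directly after `C`, with `D` applied outside the op (see the `llm_build_mamba*` changes further down). A minimal sketch of a call with Mamba-2-style operand shapes (the sizes below are illustrative placeholders, not values from a real model; the shape layout follows the updated `test_ssm_scan` at the end of this diff):

```cpp
// Illustrative shapes only; in llama.cpp these come from the model hparams.
const int64_t d_state = 128, head_dim = 64, n_head = 8, n_group = 1, n_t = 32, n_s = 2;

ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_s);
ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_t, n_s);
ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_t, n_s);
ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {1, n_head} for Mamba-2
ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_t, n_s);
ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_t, n_s);
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_s);

// D is no longer passed; ids moves up from src[7] to src[6].
ggml_tensor * y_ssm = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
```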
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 73e2fedc3..902728d8e 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node(
                 struct ggml_tensor * src4 = node->src[4];
                 struct ggml_tensor * src5 = node->src[5];
                 struct ggml_tensor * src6 = node->src[6];
-                struct ggml_tensor * src7 = node->src[7];
 
                 GGML_ASSERT(src3);
                 GGML_ASSERT(src4);
                 GGML_ASSERT(src5);
                 GGML_ASSERT(src6);
-                GGML_ASSERT(src7);
 
                 size_t offs_src3 = 0;
                 size_t offs_src4 = 0;
                 size_t offs_src5 = 0;
                 size_t offs_src6 = 0;
-                size_t offs_src7 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                 id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                 id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
                 id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
-                id<MTLBuffer> id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil;
 
                 const int64_t ne30 = src3->ne[0];
                 const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
@@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node(
 
                 const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
-                const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
-
-                const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70);
-
                 const int64_t d_state = ne00;
                 const int64_t d_inner = ne01;
                 const int64_t n_head  = ne02;
@@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
                 [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
-                [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:8];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
 
-                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:9];
-                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:10];
-                [encoder setBytes:&n_head       length:sizeof(n_head)       atIndex:11];
-                [encoder setBytes:&n_group      length:sizeof(n_group)      atIndex:12];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
-                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:14];
+                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:8];
+                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:9];
+                [encoder setBytes:&n_head       length:sizeof(n_head)       atIndex:10];
+                [encoder setBytes:&n_group      length:sizeof(n_group)      atIndex:11];
+                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12];
+                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:13];
 
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-                [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20];
+                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21];
+                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23];
+                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24];
+                [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25];
+                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26];
+                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27];
+                [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28];
 
                 // NOTE: max index is 31
                 if (ne30 == 1) {
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2f5a4d12e..05d04e8f3 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32(
        device const void * src4,
        device const void * src5,
        device const void * src6,
-       device const void * src7,
        device      float * dst,
        constant  int64_t & d_state,
        constant  int64_t & d_inner,
@@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32(
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);
-   const uint64_t nb60 = sizeof(float);
 
    const int64_t nc = d_state;
    const int64_t nr = d_inner;
@@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32(
 
    const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
 
-   device const int32_t * ids = (device const int32_t *) src7;
+   device const int32_t * ids = (device const int32_t *) src6;
 
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
    device       float * s  = (device       float *) ((device       char *) dst  + ir*nb02 + i3*nb03 + s_off);
@@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32(
        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh}
        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
-       device const float * D  = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
 
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
@@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32(
            s[i] = state;
        }
 
-       y[0] = sumf + x[0] * D[0];
+       y[0] = sumf;
 
        // recurse
        s0 = s;
@@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group(
        device const void * src4,
        device const void * src5,
        device const void * src6,
-       device const void * src7,
        device      float * dst,
        constant  int64_t & d_state,
        constant  int64_t & d_inner,
@@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group(
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);
-   const uint64_t nb60 = sizeof(float);
 
    const int64_t nc = d_state;
    const int64_t nr = d_inner;
@@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group(
 
    const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
 
-   device const int32_t * ids = (device const int32_t *) src7;
+   device const int32_t * ids = (device const int32_t *) src6;
 
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
    device       float * s  = (device       float *) ((device       char *) dst  + ir*nb02 + i3*nb03 + s_off);
@@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group(
        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
-       device const float * D  = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
 
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
@@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group(
            s[i] = state;
        }
 
-       y[0] = sumf + x[0] * D[0];
+       y[0] = sumf;
 
        // recurse
        s0 = s;
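Both Metal kernels keep the guarded softplus for `dt`. As a standalone scalar reference (an illustration, not project code): past the threshold, `float` can no longer represent the `+ 1` next to `exp(x)`, so `log(1 + exp(x))` already rounds to `x`, and returning `x` directly also avoids overflow for very large inputs.

```cpp
#include <cmath>

// Scalar reference for the dt activation used by the scan kernels.
// softplus(x) = log(1 + exp(x)); at x = 20, exp(x) ~ 4.85e8 and
// softplus(x) - x ~ 2e-9, far below float precision, so the identity
// branch is exact to rounding and safe against overflow.
static float dt_soft_plus(float x) {
    return x <= 20.0f ? std::log(1.0f + std::exp(x)) : x;
}
```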
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 91b256a4c..9036fc0be 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv(
    const int64_t n_s     = sx->ne[2];
 
    // TODO: maybe support other strides than 1?
-   // FIXME: this is always true?
    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
    GGML_ASSERT(sx->ne[1] == d_inner);
    GGML_ASSERT(n_t >= 0);
@@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan(
        struct ggml_tensor * A,
        struct ggml_tensor * B,
        struct ggml_tensor * C,
-       struct ggml_tensor * D,
        struct ggml_tensor * ids) {
    GGML_ASSERT(ggml_is_contiguous(s));
    GGML_ASSERT(ggml_is_contiguous(dt));
@@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan(
        GGML_ASSERT(B->ne[0] == d_state);
        GGML_ASSERT(B->ne[2] == n_seq_tokens);
        GGML_ASSERT(B->ne[3] == n_seqs);
-       GGML_ASSERT(D->ne[0] == n_head);
-       GGML_ASSERT(ggml_is_vector(D));
        GGML_ASSERT(ids->ne[0] == n_seqs);
        GGML_ASSERT(ggml_is_vector(ids));
        GGML_ASSERT(A->ne[1] == n_head);
@@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan(
    result->src[3] = A;
    result->src[4] = B;
    result->src[5] = C;
-   result->src[6] = D;
-   result->src[7] = ids;
+   result->src[6] = ids;
 
    return result;
}
@@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    const struct ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
    const struct ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
    const struct ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
-   const struct ggml_tensor * src6 = dst->src[6]; // D   {n_head}
-   const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}
+   const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
    const int ith = params->ith;
    const int nth = params->nth;
@@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
-   GGML_ASSERT(src6->nb[0] == sizeof(float));
-   GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
+   GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
 
    // allows optimizing the modulo since n_group should be a power of 2
    GGML_ASSERT((ng & -ng) == ng);
@@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    const int ih0 = dh*ith;
    const int ih1 = MIN(ih0 + dh, nh);
 
-   const int32_t * ids = (const int32_t *) src7->data;
+   const int32_t * ids = (const int32_t *) src6->data;
 
    for (int i3 = 0; i3 < ns; ++i3) {
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
@@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32(
            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
-           const float * D  = (const float *) ((const char *) src6->data); // {nh}
                  float * y  = (      float *) ((      char *) dst->data  + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
 
            if (src3->ne[0] == 1) {
@@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                        sumf += state * C[ig];
                        s[i] = state;
                    }
-                   y[ii] = sumf + x[ii] * D[h];
+                   y[ii] = sumf;
                }
            }
        } else {
@@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                        sumf += state * C[ig];
                        s[i] = state;
                    }
-                   y[ii] = sumf + x[ii] * D[h];
+                   y[ii] = sumf;
                }
            }
        }
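The `(ng & -ng) == ng` assertion pins `n_group` to a power of two, which is what lets the B/C indexing here and in the Metal kernels use `(ir & (ng - 1))` instead of `ir % ng`. A self-contained sketch of the identity (illustration, not project code):

```cpp
#include <cassert>
#include <cstdint>

// For a power-of-two ng, x % ng == x & (ng - 1): the mask keeps exactly the
// low bits that select the position within the group. (ng & -ng) isolates
// the lowest set bit of ng, so it equals ng precisely when ng has a single
// set bit, i.e. when ng is a power of two.
static uint32_t group_index(uint32_t ir, uint32_t ng) {
    assert(ng != 0 && (ng & -ng) == ng); // ng must be a power of two
    return ir & (ng - 1);                // same as ir % ng, without a division
}
```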
diff --git a/src/llama.cpp b/src/llama.cpp
index e84510ce8..52052caf2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7120,6 +7120,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
    {LLM_TENSOR_SSM_CONV1D,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
    {LLM_TENSOR_SSM_A,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
    {LLM_TENSOR_SSM_D,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_SSM_NORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_TIME_MIX_LERP_X,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_TIME_MIX_LN,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_CHANNEL_MIX_LERP_K,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -7227,23 +7228,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
            } break;
        case GGML_OP_SSM_CONV:
            {
-               // FIXME
-               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+               const int64_t n_seq_tokens = 512;
+               const int64_t n_seqs = 3;
+               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
            } break;
        case GGML_OP_SSM_SCAN:
            {
-               // FIXME
-               const int64_t d_state      = w->ne[0];
-               const int64_t d_inner      = w->ne[1];
+               // w is ssm_a
+               const int64_t d_state      = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
+               const int64_t n_head       = w->ne[1];
+               const int64_t head_dim     = hparams.ssm_d_inner / n_head;
+               const int64_t n_group      = hparams.ssm_n_group;
                const int64_t n_seq_tokens = 512;
-               const int64_t n_seqs       = 1;
-               ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-               ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * B  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               ggml_tensor * C  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
+               const int64_t n_seqs       = 3;
+               ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
+               ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
            } break;
        case GGML_OP_RWKV_WKV:
            {
@@ -8572,10 +8577,10 @@ static bool llm_load_tensors(
                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
 
                        // no "weight" suffix for these
-                       layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {n_head}, 0);
-                       layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {n_head}, 0);
+                       layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
+                       layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
 
-                       layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}, 0);
+                       layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
 
                        // out_proj
                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
@@ -9994,7 +9999,7 @@ static struct ggml_tensor * llm_build_rs(
 
    return states;
}
 
-// TODO: split
+// TODO: split conv and ssm
static struct ggml_tensor * llm_build_mamba(
        struct ggml_context * ctx,
       struct llama_context & lctx,
@@ -10102,13 +10107,14 @@ static struct ggml_tensor * llm_build_mamba(
        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
 
+       cur = x;
        x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs);
 
        struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
        // Custom operator to optimize the parallel associative scan
        // as described in the Annex D of the Mamba paper.
        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-       struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids);
+       struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
 
        // store last states
        ggml_build_forward_expand(graph,
@@ -10120,6 +10126,7 @@ static struct ggml_tensor * llm_build_mamba(
 
        // TODO: skip computing output earlier for unused tokens
 
+       y = ggml_add(ctx, y, ggml_mul(ctx, cur, model.layers[il].ssm_d));
        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
 
        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@@ -10184,7 +10191,7 @@ static struct ggml_tensor * llm_build_mamba2(
        struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
 
        // split the above in three
-       struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0);
+       struct ggml_tensor * z = ggml_view_4d(ctx, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
        struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
        struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
 
@@ -10230,11 +10237,9 @@ static struct ggml_tensor * llm_build_mamba2(
        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
 
        struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
-       // Use the same shape semantics for A as Mamba-1
-       struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head);
        // TODO: use semistructured matrices to implement state-space duality
        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-       struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids);
+       struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
 
        // store last states
        ggml_build_forward_expand(graph,
@@ -10242,17 +10247,16 @@ static struct ggml_tensor * llm_build_mamba2(
                ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
                ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
 
-       struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
+       struct ggml_tensor * y = ggml_view_4d(ctx, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
 
        // TODO: skip computing output earlier for unused tokens
 
+       y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
 
        // grouped RMS norm
        y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-       y = llm_build_norm(ctx, y, hparams,
-               ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL,
-               LLM_NORM_RMS, cb, il);
+       y = llm_build_norm(ctx, y, hparams, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, cb, il);
        y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs);
 
        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
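With `D` out of the scan op, the skip connection becomes a plain broadcast multiply-add in the graph, which is also why `ssm_d` is now loaded as `{1, n_head}` and `ssm_norm` as `{d_inner / n_group, n_group}` (saving the per-graph `ggml_reshape_2d`). A sketch of the broadcast semantics, with shapes as in the hunks above:

```cpp
// ggml_mul(ctx, a, b) requires ggml_can_repeat(b, a): every dim of b must
// divide the matching dim of a.
//   a = x:     {head_dim, n_head, n_seq_tokens, n_seqs}
//   b = ssm_d: {1,        n_head, 1,            1     }
// so each head's scalar D[h] multiplies all head_dim channels of that head,
// reproducing the "y[ii] = sumf + x[ii] * D[h]" term removed from the kernels.
ggml_tensor * y_skip = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
```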
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6ca254a45..95f8abbd8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case {
    const ggml_type type;
 
    const int64_t d_state;
-   const int64_t d_inner;
+   const int64_t head_dim;
    const int64_t n_head;
    const int64_t n_group;
    const int64_t n_seq_tokens;
    const int64_t n_seqs;
 
    std::string vars() override {
-       return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs);
+       return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
    }
 
    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
                  int64_t d_state = 32,
-                 int64_t d_inner = 1, // non-zero for Mamba-2
+                 int64_t head_dim = 1, // non-zero for Mamba-2
                  int64_t n_head = 32,
                  int64_t n_group = 1,
                  int64_t n_seq_tokens = 32,
                  int64_t n_seqs = 32)
-       : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+       : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
 
    ggml_tensor * build_graph(ggml_context * ctx) override {
-       ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs);
-       ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs);
-       ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
-       ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head);
-       ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-       ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-       ggml_tensor * D   = ggml_new_tensor_1d(ctx, type, n_head);
-       ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
-       ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
+       ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
+       ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+       ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+       ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
+       ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+       ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+       ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+       ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
        return out;
    }
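To cover both code paths in the backends, the test would be registered once with `head_dim == 1` (Mamba-1 semantics, `A` is `{d_state, n_head}`) and once with `head_dim > 1` (Mamba-2 semantics, `A` collapses to `{1, n_head}`). Hypothetical registrations, with illustrative values that are not necessarily the ones used upstream:

```cpp
// head_dim == 1 -> A is {d_state, n_head}: exercises the Mamba-1-style scan
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4));
// head_dim > 1 -> A is {1, n_head}: exercises the Mamba-2-style grouped scan
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4));
```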