From 273e7a495ad8c93bb9ba8123c1a3de3c68f93cf9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 30 Sep 2024 15:52:42 -0400 Subject: [PATCH] llama : avoid redundant state copy for Mamba 1 and 2 --- ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 50 ++++++------ src/llama.cpp | 154 +++++++++++++++++-------------------- tests/test-backend-ops.cpp | 54 ++++++++++--- 4 files changed, 142 insertions(+), 119 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fec6798ff..1fc53bebe 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1833,7 +1833,8 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D); + struct ggml_tensor * D, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 12e4f2694..1c4c393e5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D) { + struct ggml_tensor * D, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); @@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; @@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_3d(dt)); GGML_ASSERT(s->ne[1] == head_dim); GGML_ASSERT(s->ne[2] == n_head); - GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); GGML_ASSERT(D->ne[0] == n_head); GGML_ASSERT(ggml_is_vector(D)); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); - if (ggml_is_vector(A)) { - // Mamba-2 - GGML_ASSERT(A->ne[0] == n_head); - } else { - // Mamba-1 + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == n_head); - GGML_ASSERT(ggml_is_matrix(A)); } } @@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan( } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[4] = B; result->src[5] = C; result->src[6] = D; + result->src[7] = ids; return result; } @@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} - const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} + const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nh = src1->ne[1]; // n_head const int64_t ng = src4->ne[1]; const int64_t nt = src1->ne[2]; // number of tokens per sequence - const int64_t ns = src0->ne[3]; // number of sequences in the batch + const int64_t ns = src1->ne[3]; // number of sequences in the batch - const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(float)); + GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); + const int32_t * ids = (const int32_t *) src7->data; + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + for (int i2 = 0; i2 < nt; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} - const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - // use the output as the source when it's not the first token-wise iteration - if (i2 > 0) { s0 = s; } - - if (ggml_is_vector(src3)) { + if (src3->ne[0] == 1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop // n_head @@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32( } } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/src/llama.cpp b/src/llama.cpp index c11472112..3e1f8755f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2801,6 +2801,10 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + // NOTE: only used by recurrent models + int32_t rs_z = -1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; @@ -3381,8 +3385,6 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot( } } + // Find first to-be-cleared cell + cache.rs_z = -1; + for (int i = min; i <= max; ++i) { + if (cache.cells[i].src == -1) { + cache.rs_z = i; + break; + } + } + // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; @@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv( return cur; } -static struct ggml_tensor * llm_build_copy_mask_state( +static struct ggml_tensor * llm_build_rs( struct ggml_context * ctx, struct ggml_cgraph * graph, struct ggml_tensor * s, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t n_state, int32_t kv_size, int32_t kv_head, int32_t n_kv, - int32_t n_seqs) { + int32_t n_seqs, + bool avoid_copies = false) { struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + states_extra, ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + } + + return states; } // TODO: split @@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9639,14 +9656,14 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba( x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2( // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); + // Use the same shape semantics for A as Mamba-1 + struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10069,6 +10090,7 @@ struct llm_build_context { const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_zero; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10119,6 +10141,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + rs_zero (kv_self.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -10147,8 +10170,6 @@ struct llm_build_context { lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -10332,13 +10353,6 @@ struct llm_build_context { return lctx.inp_s_copy; } - struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - cb(lctx.inp_s_mask, "inp_s_mask", -1); - ggml_set_input(lctx.inp_s_mask); - return lctx.inp_s_mask; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -13901,7 +13915,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -13912,15 +13925,13 @@ struct llm_build_context { switch (version) { case 2: - cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; case 1: default: - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; } @@ -15946,7 +15957,6 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); @@ -15955,11 +15965,11 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, + struct ggml_tensor * token_shift = llm_build_rs(ctx0, + gf, kv_self.k_l[il], state_copy, rs_zero, hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); - struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, + struct ggml_tensor * wkv_states = llm_build_rs(ctx0, + gf, kv_self.v_l[il], state_copy, rs_zero, hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); @@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} - static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; @@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; @@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const uint32_t cell_id = i + kv_self.head; llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0) { + GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_self.rs_z; + } + if ((uint32_t) kv_cell.src >= kv_self.size) { + // ignore out-of-bound sources kv_cell.src = cell_id; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896def..092639eed 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { const int64_t d_state; const int64_t d_inner; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_MUL_MAT @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 #if 1 for (ggml_type type_a : base_types) {