llama : avoid redundant state copy for Mamba 1 and 2

commit 273e7a495a
parent 0e601cafe9
Author: Francis Couture-Harpin
Date:   2024-09-30 15:52:42 -04:00
4 changed files with 142 additions and 119 deletions

ggml.h

@@ -1833,7 +1833,8 @@ extern "C" {
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
             struct ggml_tensor  * C,
-            struct ggml_tensor  * D);
+            struct ggml_tensor  * D,
+            struct ggml_tensor  * ids);

     // partition into non-overlapping windows with padding if needed
     // example:

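The extra ids operand lets the scan gather each sequence's initial state directly from the recurrent-state buffer, instead of requiring the caller to copy the selected states into a dense tensor first. A minimal sketch of the new calling convention, assuming the operand shapes documented in ggml.c below; the helper name build_ssm_scan_with_ids and the shape variables are illustrative, not part of ggml:

    #include "ggml.h"

    // s holds one state per cache cell; only the n_seqs entries named by ids are scanned.
    static struct ggml_tensor * build_ssm_scan_with_ids(
            struct ggml_context * ctx,
            struct ggml_tensor  * s,   // {d_state, head_dim, n_head, kv_size}
            struct ggml_tensor  * x,   // {head_dim, n_head, n_seq_tokens, n_seqs}
            struct ggml_tensor  * dt,  // {n_head, n_seq_tokens, n_seqs}
            struct ggml_tensor  * A,   // {d_state, n_head} for Mamba-1, {1, n_head} for Mamba-2
            struct ggml_tensor  * B,   // {d_state, n_group, n_seq_tokens, n_seqs}
            struct ggml_tensor  * C,   // {d_state, n_group, n_seq_tokens, n_seqs}
            struct ggml_tensor  * D,   // {n_head}
            int64_t               n_seqs) {
        // ids[i] is the cache cell whose state seeds the scan of sequence i
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
        ggml_set_input(ids); // filled with cell indices at eval time
        return ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
    }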
ggml.c

@@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
         struct ggml_tensor  * C,
-        struct ggml_tensor  * D) {
+        struct ggml_tensor  * D,
+        struct ggml_tensor  * ids) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
@@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
     GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);

     {
         const int64_t d_state = s->ne[0];
@@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(ggml_is_3d(dt));
         GGML_ASSERT(s->ne[1] == head_dim);
         GGML_ASSERT(s->ne[2] == n_head);
-        GGML_ASSERT(s->ne[3] == n_seqs);
         GGML_ASSERT(B->ne[0] == d_state);
         GGML_ASSERT(B->ne[2] == n_seq_tokens);
         GGML_ASSERT(B->ne[3] == n_seqs);
         GGML_ASSERT(D->ne[0] == n_head);
         GGML_ASSERT(ggml_is_vector(D));
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+        GGML_ASSERT(A->ne[1] == n_head);
+        GGML_ASSERT(ggml_is_matrix(A));

-        if (ggml_is_vector(A)) {
-            // Mamba-2
-            GGML_ASSERT(A->ne[0] == n_head);
-        } else {
-            // Mamba-1
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
             GGML_ASSERT(A->ne[0] == d_state);
-            GGML_ASSERT(A->ne[1] == n_head);
-            GGML_ASSERT(ggml_is_matrix(A));
         }
     }
@@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan(
     }

     // concatenated y + ssm_states
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);

     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[4] = B;
     result->src[5] = C;
     result->src[6] = D;
+    result->src[7] = ids;

     return result;
 }
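Because the trailing part of the result now holds only the ids->ne[0] selected states rather than the whole s buffer, a caller splits the flat output accordingly. A sketch under the same shape assumptions; the helper and its parameters are illustrative:

    #include "ggml.h"

    // y_ssm is the flat F32 result of ggml_ssm_scan; x is the input it was built from.
    static void split_ssm_scan_result(
            struct ggml_context * ctx,
            struct ggml_tensor  * y_ssm,
            struct ggml_tensor  * x,          // {head_dim, n_head, n_seq_tokens, n_seqs}
            int64_t d_state, int64_t head_dim, int64_t n_head, int64_t n_seqs,
            struct ggml_tensor ** y_out,      // token-wise outputs, same shape as x
            struct ggml_tensor ** states_out  /* updated states, one block per scanned sequence */) {
        // the token-wise outputs come first and have the same shape as x
        *y_out = ggml_view_4d(ctx, y_ssm,
                head_dim, n_head, x->ne[2], n_seqs,
                head_dim*sizeof(float),
                head_dim*n_head*sizeof(float),
                head_dim*n_head*x->ne[2]*sizeof(float),
                0);
        // the updated states follow, one {d_state, head_dim, n_head} block per scanned sequence
        *states_out = ggml_view_1d(ctx, y_ssm,
                d_state*head_dim*n_head*n_seqs,
                ggml_nelements(x)*ggml_element_size(y_ssm));
    }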
@@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs}
+    const struct ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
     const struct ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
-    const struct ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {n_head}
+    const struct ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
     const struct ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src6 = dst->src[6]; // D  {n_head}
+    const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}

     const int ith = params->ith;
     const int nth = params->nth;
@@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int64_t nh = src1->ne[1]; // n_head
     const int64_t ng = src4->ne[1];
     const int64_t nt = src1->ne[2]; // number of tokens per sequence
-    const int64_t ns = src0->ne[3]; // number of sequences in the batch
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch

-    const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
@@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(float));
+    GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
     // allows optimizing the modulo since n_group should be a power of 2
     GGML_ASSERT((ng & -ng) == ng);
@@ -16673,22 +16677,22 @@
     const int ih0 = dh*ith;
     const int ih1 = MIN(ih0 + dh, nh);

+    const int32_t * ids = (const int32_t *) src7->data;
+
     for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *)  dst->data +  i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
         for (int i2 = 0; i2 < nt; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
             const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
             const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
-            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
             const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
             const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
             const float * D  = (const float *) ((const char *) src6->data); // {nh}
                   float * y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
-                  float * s  = (      float *) ((      char *)  dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}

-            // use the output as the source when it's not the first token-wise iteration
-            if (i2 > 0) { s0 = s; }
-
-            if (ggml_is_vector(src3)) {
+            if (src3->ne[0] == 1) {
                 // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

                 // n_head
@@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     }
                 }
             }
+
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
     }
 }

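The restructured loops gather each sequence's initial state through ids, write the updated state into the destination buffer once, and from the second token on read it back from there, so no state is ever copied up front. The pattern in isolation, as a stripped-down sketch rather than the actual kernel:

    #include <cstdint>

    // all_states: kv_size states of state_size floats each; out_states: n_seqs slots.
    static void scan_state_reuse_sketch(const float * all_states, const int32_t * ids,
                                        float * out_states, int64_t n_seqs,
                                        int64_t n_tokens, int64_t state_size) {
        for (int64_t i3 = 0; i3 < n_seqs; ++i3) {
            const float * s0 = all_states + ids[i3]*state_size; // gathered by index, not copied
            float       * s  = out_states + i3*state_size;      // written in place
            for (int64_t i2 = 0; i2 < n_tokens; ++i2) {
                // stand-in for the real SSM update of s from s0 and the token's inputs
                for (int64_t k = 0; k < state_size; ++k) {
                    s[k] = s0[k];
                }
                s0 = s; // after the first token, the output becomes the source
            }
        }
    }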
llama.cpp

@@ -2801,6 +2801,10 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;

+    // first zero-ed state
+    // NOTE: only used by recurrent models
+    int32_t rs_z = -1;
+
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
@@ -3381,8 +3385,6 @@ struct llama_context {
     struct ggml_tensor * inp_mean;          // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;           // I32 [n_batch]
     struct ggml_tensor * inp_s_copy;        // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;        // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;         // I32 [n_kv, n_batch]
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
@@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot(
         }
     }

+    // Find first to-be-cleared cell
+    cache.rs_z = -1;
+    for (int i = min; i <= max; ++i) {
+        if (cache.cells[i].src == -1) {
+            cache.rs_z = i;
+            break;
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     cache.head = min;
     cache.n    = max - min + 1;
@@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv(
     return cur;
 }

-static struct ggml_tensor * llm_build_copy_mask_state(
+static struct ggml_tensor * llm_build_rs(
         struct ggml_context * ctx,
          struct ggml_cgraph * graph,
          struct ggml_tensor * s,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   n_state,
                      int32_t   kv_size,
                      int32_t   kv_head,
                      int32_t   n_kv,
-                     int32_t   n_seqs) {
+                     int32_t   n_seqs,
+                        bool   avoid_copies = false) {
     struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);

-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx, states, state_copy);
-
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx, states, state_mask);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0));

     // copy states which won't be changed further (between n_seqs and n_kv)
+    struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(graph,
         ggml_cpy(ctx,
-            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
+            states_extra,
             ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // this shrinks the tensors's ne[1] to n_kv
+        states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0));
+        // the part of the states that will be used and modified
+        states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+    }
+
+    return states;
 }

 // TODO: split
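Instead of multiplying every gathered state by a 0/1 mask, llm_build_rs zeroes a single representative cell (rs_zero) in place; every cleared sequence then points its copy source at that cell, and when nothing needs clearing the view has zero elements so the scale is a no-op. With avoid_copies=true the gather is skipped entirely and the row indices are handed to ggml_ssm_scan instead, which is what the Mamba builders below do. The zero-sized-view trick in isolation, as a sketch with an illustrative helper name:

    #include "ggml.h"

    // Zero row `row` of a 2D tensor when row >= 0, and do nothing otherwise,
    // without adding any branch to the graph: the view simply covers 0 elements.
    static void zero_row_if(struct ggml_context * ctx, struct ggml_cgraph * graph,
                            struct ggml_tensor * states_2d, int32_t row) {
        struct ggml_tensor * maybe_row = ggml_view_1d(ctx, states_2d,
                states_2d->ne[0]*(row >= 0),        // 0 elements when row < 0
                row*states_2d->nb[1]*(row >= 0));   // offset also forced to 0
        ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, maybe_row, 0));
    }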
@@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba(
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   kv_head,
                      int32_t   n_kv,
              const llm_build_cb & cb,
@@ -9639,14 +9656,14 @@
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];

     // (ab)using the KV cache to store the states
-    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
-            graph, conv_states_all, state_copy, state_mask,
+    struct ggml_tensor * conv = llm_build_rs(ctx,
+            graph, conv_states_all, state_copy, rs_zero,
             hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
     conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
-    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
-            graph, ssm_states_all, state_copy, state_mask,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
-    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
+    struct ggml_tensor * ssm = llm_build_rs(ctx,
+            graph, ssm_states_all, state_copy, rs_zero,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
+    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);

     // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
     cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba(
     x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs);

+    struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
     // Custom operator to optimize the parallel associative scan
     // as described in the Annex D of the Mamba paper.
     // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-    struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
+    struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids);

     // store last states
     ggml_build_forward_expand(graph,
@@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2(
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   kv_head,
                      int32_t   n_kv,
              const llm_build_cb & cb,
@@ -9772,14 +9790,14 @@
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];

     // (ab)using the KV cache to store the states
-    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
-            graph, conv_states_all, state_copy, state_mask,
+    struct ggml_tensor * conv = llm_build_rs(ctx,
+            graph, conv_states_all, state_copy, rs_zero,
             hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
     conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
-            graph, ssm_states_all, state_copy, state_mask,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
-    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
+    struct ggml_tensor * ssm = llm_build_rs(ctx,
+            graph, ssm_states_all, state_copy, rs_zero,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
+    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);

     // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
     cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9835,9 +9853,12 @@
     // {n_head, n_seq_tokens, n_seqs}
     dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);

+    struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
+    // Use the same shape semantics for A as Mamba-1
+    struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head);
     // TODO: use semistructured matrices to implement state-space duality
     // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-    struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
+    struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids);

     // store last states
     ggml_build_forward_expand(graph,
@@ -10069,6 +10090,7 @@ struct llm_build_context {
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
     const int32_t kv_head;    // index of where we store new KV data in the cache
+    const int32_t rs_zero;    // the first zero-ed recurrent state
     const int32_t n_ctx_orig;

     const bool flash_attn;
@@ -10119,6 +10141,7 @@
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        rs_zero          (kv_self.rs_z),
         n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
@@ -10147,8 +10170,6 @@ struct llm_build_context {
         lctx.inp_mean          = nullptr;
         lctx.inp_cls           = nullptr;
         lctx.inp_s_copy        = nullptr;
-        lctx.inp_s_mask        = nullptr;
-        lctx.inp_s_seq         = nullptr;
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
@@ -10332,13 +10353,6 @@ struct llm_build_context {
         return lctx.inp_s_copy;
     }

-    struct ggml_tensor * build_inp_s_mask() {
-        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        cb(lctx.inp_s_mask, "inp_s_mask", -1);
-        ggml_set_input(lctx.inp_s_mask);
-        return lctx.inp_s_mask;
-    }
-
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
@@ -13901,7 +13915,6 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

         struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();

         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -13912,15 +13925,13 @@ struct llm_build_context {
             switch (version) {
                 case 2:
-                    cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur,
-                            state_copy, state_mask,
-                            kv_head, n_kv, cb, il);
+                    cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy,
+                            rs_zero, kv_head, n_kv, cb, il);
                     break;
                 case 1:
                 default:
-                    cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
-                            state_copy, state_mask,
-                            kv_head, n_kv, cb, il);
+                    cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy,
+                            rs_zero, kv_head, n_kv, cb, il);
                     break;
             }
@@ -15946,7 +15957,6 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

         struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();

         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
         inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
@@ -15955,11 +15965,11 @@ struct llm_build_context {
             const llama_layer * layer = &model.layers[il];

             // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
+            struct ggml_tensor * token_shift = llm_build_rs(ctx0,
+                    gf, kv_self.k_l[il], state_copy, rs_zero,
                     hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);

-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
+            struct ggml_tensor * wkv_states = llm_build_rs(ctx0,
+                    gf, kv_self.v_l[il], state_copy, rs_zero,
                     hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

             cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
@@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }

-static void llama_set_s_copy(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
@@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;

-        if (lctx.inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
-            float * data = (float *) lctx.inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
         if (lctx.inp_s_copy) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
             int32_t * data = (int32_t *) lctx.inp_s_copy->data;
@@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                 const uint32_t  cell_id = i + kv_self.head;
                 llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];

-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                if (kv_cell.src < 0) {
+                    GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source
+                    kv_cell.src = kv_self.rs_z;
+                }
+                if ((uint32_t) kv_cell.src >= kv_self.size) {
+                    // ignore out-of-bound sources
                     kv_cell.src = cell_id;
                 }

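With inp_s_mask gone, the copy indices alone encode clearing: a cell that has no previous state simply names the zeroed cell rs_z as its source, so the gather in the graph picks up the freshly cleared row. Roughly, the fill loop behaves like the following sketch, with a simplified cell type standing in for llama_kv_cell:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct cell_sketch { int32_t src; }; // simplified stand-in for llama_kv_cell

    // data[i] becomes the cell index whose state seeds cell (head + i).
    static void fill_s_copy_sketch(std::vector<cell_sketch> & cells, uint32_t head,
                                   uint32_t n_kv, uint32_t kv_size, int32_t rs_z,
                                   int32_t * data) {
        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t cell_id = head + i;
            cell_sketch & c = cells[cell_id];
            if (c.src < 0) {
                assert(rs_z >= 0);  // a zeroed cell must exist in the current slot range
                c.src = rs_z;       // start from the cleared state
            }
            if ((uint32_t) c.src >= kv_size) {
                c.src = cell_id;    // ignore out-of-bound sources
            }
            data[i] = c.src;
            c.src = cell_id;        // the copy happens in the graph; the next batch reads its own cell
        }
    }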
tests/test-backend-ops.cpp

@@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case {
     const int64_t d_state;
     const int64_t d_inner;
+    const int64_t n_head;
+    const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;

     std::string vars() override {
-        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+        return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs);
     }

     test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+            int64_t d_state = 32,
+            int64_t d_inner = 1, // non-zero for Mamba-2
+            int64_t n_head = 32,
+            int64_t n_group = 1,
+            int64_t n_seq_tokens = 32,
+            int64_t n_seqs = 32)
+        : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
-        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
-        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs);
+        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head);
+        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * D   = ggml_new_tensor_1d(ctx, type, n_head);
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
         return out;
     }
+
+    // similar to test_mul_mat_id
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
 };
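The I32 ids tensor is filled with a random permutation of 0..n_seqs-1, so each state slot is consumed exactly once but in an arbitrary order, which is what the KV-cell mapping produces in practice. Generating such a permutation on its own looks like this (a standalone sketch, not test-framework code):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <random>
    #include <vector>

    // Build a random permutation of {0, ..., n - 1} to use as state indices.
    static std::vector<int32_t> make_ids(int64_t n) {
        std::vector<int32_t> ids(n);
        std::iota(ids.begin(), ids.end(), 0); // 0, 1, ..., n-1
        std::random_device rd;
        std::default_random_engine rng(rd());
        std::shuffle(ids.begin(), ids.end(), rng);
        return ids;
    }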
// GGML_OP_MUL_MAT // GGML_OP_MUL_MAT
@@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));

-    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32,  16,  1, 1024, 1, 32, 4)); // Mamba-1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32,   32, 2, 32, 4)); // Mamba-2

 #if 1
     for (ggml_type type_a : base_types) {