llama : avoid redundant state copy for Mamba 1 and 2

commit 273e7a495a
parent 0e601cafe9
mirror of https://github.com/ggerganov/llama.cpp.git
@@ -1833,7 +1833,8 @@ extern "C" {
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
             struct ggml_tensor  * C,
-            struct ggml_tensor  * D);
+            struct ggml_tensor  * D,
+            struct ggml_tensor  * ids);

     // partition into non-overlapping windows with padding if needed
     // example:
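For orientation, a minimal sketch of how the updated operator could be invoked; the helper name and dimension variables are illustrative assumptions, not code from this commit. Shapes mirror the assertions added in ggml.c below, with ids selecting which source state row in s each sequence of the batch starts from.

// Hypothetical usage sketch (assumed names, not part of the commit).
static struct ggml_tensor * build_ssm_scan_example(struct ggml_context * ctx,
        int64_t d_state, int64_t head_dim, int64_t n_head, int64_t n_group,
        int64_t n_seq_tokens, int64_t n_seqs, int64_t kv_size) {
    // the source states may cover the whole recurrent cache (kv_size >= n_seqs)
    struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, kv_size);
    struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {1, n_head} (Mamba-2) or {d_state, n_head} (Mamba-1)
    struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
    struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
    struct ggml_tensor * D   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head);
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // row of s used as the initial state of each sequence
    // the op gathers s[ids[i]] itself, so no separate copy/mask pass over the states is needed
    return ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
}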
@@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
         struct ggml_tensor  * C,
-        struct ggml_tensor  * D) {
+        struct ggml_tensor  * D,
+        struct ggml_tensor  * ids) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
@@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
     GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);

     {
         const int64_t d_state = s->ne[0];
@@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(ggml_is_3d(dt));
         GGML_ASSERT(s->ne[1] == head_dim);
         GGML_ASSERT(s->ne[2] == n_head);
-        GGML_ASSERT(s->ne[3] == n_seqs);
         GGML_ASSERT(B->ne[0] == d_state);
         GGML_ASSERT(B->ne[2] == n_seq_tokens);
         GGML_ASSERT(B->ne[3] == n_seqs);
         GGML_ASSERT(D->ne[0] == n_head);
         GGML_ASSERT(ggml_is_vector(D));
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+        GGML_ASSERT(A->ne[1] == n_head);
+        GGML_ASSERT(ggml_is_matrix(A));

-        if (ggml_is_vector(A)) {
-            // Mamba-2
-            GGML_ASSERT(A->ne[0] == n_head);
-        } else {
-            // Mamba-1
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
             GGML_ASSERT(A->ne[0] == d_state);
-            GGML_ASSERT(A->ne[1] == n_head);
-            GGML_ASSERT(ggml_is_matrix(A));
         }
     }

@@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan(
     }

     // concatenated y + ssm_states
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);

     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
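To make the new size concrete (illustrative numbers, not from the commit): with d_state = 128, head_dim = 64, n_head = 32 and n_seqs = 2, the states part of result is 128*64*32*2 = 524288 floats. If s is a view of the whole recurrent cache with, say, 256 cells, the old ggml_nelements(s) term would instead have reserved 128*64*32*256 = 67108864 floats, even though only the ids->ne[0] selected states are written back.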
@@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[4] = B;
     result->src[5] = C;
     result->src[6] = D;
+    result->src[7] = ids;

     return result;
 }
@@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs}
+    const struct ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
     const struct ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
-    const struct ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {n_head}
+    const struct ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
     const struct ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src6 = dst->src[6]; // D  {n_head}
+    const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}

     const int ith = params->ith;
     const int nth = params->nth;
@@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int64_t nh = src1->ne[1]; // n_head
     const int64_t ng = src4->ne[1];
     const int64_t nt = src1->ne[2]; // number of tokens per sequence
-    const int64_t ns = src0->ne[3]; // number of sequences in the batch
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch

-    const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
@@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(float));
+    GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
     // allows optimizing the modulo since n_group should be a power of 2
     GGML_ASSERT((ng & -ng) == ng);

@@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ih0 = dh*ith;
     const int ih1 = MIN(ih0 + dh, nh);

+    const int32_t * ids = (const int32_t *) src7->data;
+
     for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
         for (int i2 = 0; i2 < nt; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
             const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
             const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
-            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
             const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
             const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
             const float * D  = (const float *) ((const char *) src6->data); // {nh}
                   float * y  = (      float *) ((      char *) dst->data  + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
-                  float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
-
-            // use the output as the source when it's not the first token-wise iteration
-            if (i2 > 0) { s0 = s; }

-            if (ggml_is_vector(src3)) {
+            if (src3->ne[0] == 1) {
                 // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

                 // n_head
@@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     }
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
     }
 }

src/llama.cpp
@@ -2801,6 +2801,10 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;

+    // first zero-ed state
+    // NOTE: only used by recurrent models
+    int32_t rs_z = -1;
+
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;

@@ -3381,8 +3385,6 @@ struct llama_context {
     struct ggml_tensor * inp_mean;          // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;           // I32 [n_batch]
     struct ggml_tensor * inp_s_copy;        // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;        // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;         // I32 [n_kv, n_batch]
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
@@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot(
             }
         }

+        // Find first to-be-cleared cell
+        cache.rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (cache.cells[i].src == -1) {
+                cache.rs_z = i;
+                break;
+            }
+        }
+
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n    = max - min + 1;
@@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv(
     return cur;
 }

-static struct ggml_tensor * llm_build_copy_mask_state(
+static struct ggml_tensor * llm_build_rs(
         struct ggml_context * ctx,
          struct ggml_cgraph * graph,
          struct ggml_tensor * s,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   n_state,
                      int32_t   kv_size,
                      int32_t   kv_head,
                      int32_t   n_kv,
-                     int32_t   n_seqs) {
+                     int32_t   n_seqs,
+                        bool   avoid_copies = false) {
     struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);

-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx, states, state_copy);
-
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx, states, state_mask);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0));

     // copy states which won't be changed further (between n_seqs and n_kv)
+    struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(graph,
         ggml_cpy(ctx,
-            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
+            states_extra,
             ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // this shrinks the tensors's ne[1] to n_kv
+        states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0));
+        // the part of the states that will be used and modified
+        states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+    }
+
+    return states;
 }

 // TODO: split
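A side note on the rs_zero handling above: ggml tolerates zero-element views, so when there is no to-be-cleared cell (rs_zero == -1) both the element count and the offset collapse to zero and the scale node does nothing. A stand-alone sketch of the same pattern, with assumed names (not code from the commit):

// Hypothetical sketch of the "clear one row only when it exists" pattern used by llm_build_rs.
// When row < 0 the view has 0 elements, so the in-place scale is a no-op.
static void clear_row_if_any(struct ggml_context * ctx, struct ggml_cgraph * gf,
                             struct ggml_tensor * states_2d, int32_t row) {
    const int64_t n   = states_2d->ne[0] * (row >= 0);          // 0 elements when there is no row to clear
    const size_t  off = row >= 0 ? row * states_2d->nb[1] : 0;  // byte offset of that row
    struct ggml_tensor * row_view = ggml_view_1d(ctx, states_2d, n, off);
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, row_view, 0.0f));
}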
@@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba(
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   kv_head,
                      int32_t   n_kv,
          const llm_build_cb & cb,
|
||||
struct ggml_tensor * ssm_states_all = kv.v_l[il];
|
||||
|
||||
// (ab)using the KV cache to store the states
|
||||
struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
|
||||
graph, conv_states_all, state_copy, state_mask,
|
||||
struct ggml_tensor * conv = llm_build_rs(ctx,
|
||||
graph, conv_states_all, state_copy, rs_zero,
|
||||
hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
|
||||
conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
|
||||
struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
|
||||
graph, ssm_states_all, state_copy, state_mask,
|
||||
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
|
||||
ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
|
||||
struct ggml_tensor * ssm = llm_build_rs(ctx,
|
||||
graph, ssm_states_all, state_copy, rs_zero,
|
||||
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
|
||||
ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);
|
||||
|
||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
||||
@@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba(

         x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs);

+        struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
         // Custom operator to optimize the parallel associative scan
         // as described in the Annex D of the Mamba paper.
         // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
+        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids);

         // store last states
         ggml_build_forward_expand(graph,
@@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2(
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
+                     int32_t   rs_zero,
                      int32_t   kv_head,
                      int32_t   n_kv,
          const llm_build_cb & cb,
@@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2(
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];

     // (ab)using the KV cache to store the states
-    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
-            graph, conv_states_all, state_copy, state_mask,
+    struct ggml_tensor * conv = llm_build_rs(ctx,
+            graph, conv_states_all, state_copy, rs_zero,
             hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
     conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
-            graph, ssm_states_all, state_copy, state_mask,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
-    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
+    struct ggml_tensor * ssm = llm_build_rs(ctx,
+            graph, ssm_states_all, state_copy, rs_zero,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
+    ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);

     // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
     cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2(
         // {n_head, n_seq_tokens, n_seqs}
         dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);

+        struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0);
+        // Use the same shape semantics for A as Mamba-1
+        struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head);
         // TODO: use semistructured matrices to implement state-space duality
         // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
+        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids);

         // store last states
         ggml_build_forward_expand(graph,
@@ -10069,6 +10090,7 @@ struct llm_build_context {
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
     const int32_t kv_head;  // index of where we store new KV data in the cache
+    const int32_t rs_zero;  // the first zero-ed recurrent state
     const int32_t n_ctx_orig;

     const bool flash_attn;
@@ -10119,6 +10141,7 @@ struct llm_build_context {
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        rs_zero          (kv_self.rs_z),
         n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
@@ -10147,8 +10170,6 @@ struct llm_build_context {
         lctx.inp_mean          = nullptr;
         lctx.inp_cls           = nullptr;
         lctx.inp_s_copy        = nullptr;
-        lctx.inp_s_mask        = nullptr;
-        lctx.inp_s_seq         = nullptr;
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
@@ -10332,13 +10353,6 @@ struct llm_build_context {
         return lctx.inp_s_copy;
     }

-    struct ggml_tensor * build_inp_s_mask() {
-        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        cb(lctx.inp_s_mask, "inp_s_mask", -1);
-        ggml_set_input(lctx.inp_s_mask);
-        return lctx.inp_s_mask;
-    }
-
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
@@ -13901,7 +13915,6 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

         struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();

         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -13912,15 +13925,13 @@ struct llm_build_context {

             switch (version) {
                 case 2:
-                    cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur,
-                            state_copy, state_mask,
-                            kv_head, n_kv, cb, il);
+                    cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy,
+                            rs_zero, kv_head, n_kv, cb, il);
                     break;
                 case 1:
                 default:
-                    cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
-                            state_copy, state_mask,
-                            kv_head, n_kv, cb, il);
+                    cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy,
+                            rs_zero, kv_head, n_kv, cb, il);
                     break;
             }

@@ -15946,7 +15957,6 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
         struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();

         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
         inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
@@ -15955,11 +15965,11 @@ struct llm_build_context {
             const llama_layer * layer = &model.layers[il];

             // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
+            struct ggml_tensor * token_shift = llm_build_rs(ctx0,
+                    gf, kv_self.k_l[il], state_copy, rs_zero,
                     hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
+            struct ggml_tensor * wkv_states = llm_build_rs(ctx0,
+                    gf, kv_self.v_l[il], state_copy, rs_zero,
                     hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

             cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
@@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }

-static void llama_set_s_copy(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
@@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;

-        if (lctx.inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
-            float * data = (float *) lctx.inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
         if (lctx.inp_s_copy) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
             int32_t * data = (int32_t *) lctx.inp_s_copy->data;
@@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                 const uint32_t cell_id = i + kv_self.head;
                 llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];

-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                if (kv_cell.src < 0) {
+                    GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source
+                    kv_cell.src = kv_self.rs_z;
+                }
+                if ((uint32_t) kv_cell.src >= kv_self.size) {
+                    // ignore out-of-bound sources
                     kv_cell.src = cell_id;
                 }

tests/test-backend-ops.cpp

@@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case {

     const int64_t d_state;
     const int64_t d_inner;
+    const int64_t n_head;
+    const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;

     std::string vars() override {
-        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+        return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs);
     }

     test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+            int64_t d_state = 32,
+            int64_t d_inner = 1, // non-zero for Mamba-2
+            int64_t n_head = 32,
+            int64_t n_group = 1,
+            int64_t n_seq_tokens = 32,
+            int64_t n_seqs = 32)
+        : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
-        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1, 1 }.data());
-        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs);
+        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head);
+        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * D   = ggml_new_tensor_1d(ctx, type, n_head);
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
         return out;
     }

+    // similar to test_mul_mat_id
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
 };

 // GGML_OP_MUL_MAT
@@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));

-    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2

 #if 1
     for (ggml_type type_a : base_types) {
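For reference, the arguments of the two new cases line up with the updated constructor as (type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs): the Mamba-1 case keeps d_inner = 1, so build_graph above gives A the {d_state, n_head} shape, while the Mamba-2 case uses d_inner = 32 with n_group = 2, so A collapses to {1, n_head} and B/C carry two groups.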