llama : initial Mamba-2 support

This commit is contained in:
Francis Couture-Harpin 2024-08-01 10:43:42 -04:00
parent a1631e53f6
commit 1f0fea70fb
7 changed files with 490 additions and 82 deletions

View File

@ -2788,6 +2788,73 @@ class MambaModel(Model):
return [(new_name, data_torch)] return [(new_name, data_torch)]
@Model.register("Mamba2ForCausalLM")
class Mamba2Model(Model):
model_arch = gguf.MODEL_ARCH.MAMBA2
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size
if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
elif (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield (new_name, data_torch)
@Model.register("CohereForCausalLM") @Model.register("CohereForCausalLM")
class CommandR2Model(Model): class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R model_arch = gguf.MODEL_ARCH.COMMAND_R

View File

@ -1787,7 +1787,8 @@ extern "C" {
struct ggml_tensor * dt, struct ggml_tensor * dt,
struct ggml_tensor * A, struct ggml_tensor * A,
struct ggml_tensor * B, struct ggml_tensor * B,
struct ggml_tensor * C); struct ggml_tensor * C,
struct ggml_tensor * D);
// partition into non-overlapping windows with padding if needed // partition into non-overlapping windows with padding if needed
// example: // example:

View File

@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt, struct ggml_tensor * dt,
struct ggml_tensor * A, struct ggml_tensor * A,
struct ggml_tensor * B, struct ggml_tensor * B,
struct ggml_tensor * C) { struct ggml_tensor * C,
struct ggml_tensor * D) {
GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A)); GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(ggml_is_matrix(A)); GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt)); GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C)); GGML_ASSERT(ggml_are_same_shape(B, C));
{ {
const int64_t d_state = s->ne[0]; const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1]; const int64_t head_dim = x->ne[0];
const int64_t n_seq_tokens = x->ne[1]; const int64_t n_head = x->ne[1];
const int64_t n_seqs = x->ne[2]; const int64_t n_seq_tokens = x->ne[2];
const int64_t n_seqs = x->ne[3];
GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(dt->ne[0] == n_head);
GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(dt->ne[1] == n_seq_tokens);
GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(dt->ne[2] == n_seqs);
GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(s->ne[3] == n_seqs);
GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_seq_tokens); GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[2] == n_seqs); GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(D->ne[0] == n_head);
GGML_ASSERT(ggml_is_vector(D));
if (ggml_is_vector(A)) {
// Mamba-2
GGML_ASSERT(A->ne[0] == n_head);
} else {
// Mamba-1
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
}
} }
bool is_node = false; bool is_node = false;
@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A; result->src[3] = A;
result->src[4] = B; result->src[4] = B;
result->src[5] = C; result->src[5] = C;
result->src[6] = D;
return result; return result;
} }
@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32( static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src2 = dst->src[2]; // dt const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src5 = dst->src[5]; // C const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int64_t nc = src0->ne[0]; // d_state const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner const int64_t nr = src0->ne[1]; // dim
const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t nh = src1->ne[1]; // n_head
const int64_t n_s = src0->ne[2]; // number of sequences in the batch const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src0->ne[3]; // number of sequences in the batch
const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C GGML_ASSERT(src6->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // allows optimizing the modulo since n_group should be a power of 2
// required for per-sequence offsets for states GGML_ASSERT((ng & -ng) == ng);
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
// rows per thread // heads per thread
const int dr = (nr + nth - 1)/nth; const int dh = (nh + nth - 1)/nth;
// row range for this thread // head range for this thread
const int ir0 = dr*ith; const int ih0 = dh*ith;
const int ir1 = MIN(ir0 + dr, nr); const int ih1 = MIN(ih0 + dh, nh);
const int ir = ir1 - ir0;
for (int i3 = 0; i3 < n_s; ++i3) { for (int i3 = 0; i3 < ns; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) { for (int i2 = 0; i2 < nt; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} const float * D = (const float *) ((const char *) src6->data); // {nh}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
// use the output as the source for the next token-wise iterations // use the output as the source when it's not the first token-wise iteration
if (i2 > 0) { s0 = s; } if (i2 > 0) { s0 = s; }
// d_inner if (ggml_is_vector(src3)) {
for (int i1 = 0; i1 < ir; ++i1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; // n_head
float x_dt = x[i1] * dt_soft_plus; for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
// TODO: SIMD implementation
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
// d_state // d_state
for (int i0 = 0; i0 < nc; ++i0) { for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc; const int ii = i0 + i*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x // state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); const float state = (s0[ii] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C) // y = rowwise_dotprod(state, C)
sumf += state * C[i0]; sumf += state * C[ig];
s[i] = state; s[ii] = state;
}
y[i] = sumf + x[i] * D[h];
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int ii = i0 + i*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[ii] = state;
}
y[i] = sumf + x[i] * D[h];
}
} }
y[i1] = sumf;
} }
} }
} }

View File

@ -130,6 +130,7 @@ class Keys:
INNER_SIZE = "{arch}.ssm.inner_size" INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size" STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank" TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class Tokenizer: class Tokenizer:
@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum):
GEMMA2 = auto() GEMMA2 = auto()
STARCODER2 = auto() STARCODER2 = auto()
MAMBA = auto() MAMBA = auto()
MAMBA2 = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
DBRX = auto() DBRX = auto()
@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum):
SSM_DT = auto() SSM_DT = auto()
SSM_A = auto() SSM_A = auto()
SSM_D = auto() SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto() SSM_OUT = auto()
ATTN_Q_A = auto() ATTN_Q_A = auto()
ATTN_Q_B = auto() ATTN_Q_B = auto()
@ -338,6 +341,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.MAMBA2: "mamba2",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.DBRX: "dbrx",
@ -399,6 +403,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
@ -869,6 +874,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT, MODEL_TENSOR.SSM_OUT,
], ],
MODEL_ARCH.MAMBA2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [ MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
@ -1373,6 +1391,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization # tokenization

View File

@ -730,6 +730,9 @@ class GGUFWriter:
def add_ssm_time_step_rank(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
def add_ssm_group_count(self, value: int) -> None:
self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
def add_ssm_dt_b_c_rms(self, value: bool) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

View File

@ -396,7 +396,7 @@ class TensorNameMap:
"encoder.layers.{bid}.norm2", # nomic-bert "encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
), ),
MODEL_TENSOR.SSM_IN: ( MODEL_TENSOR.SSM_IN: (
@ -429,6 +429,10 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.D", "backbone.layers.{bid}.mixer.D",
), ),
MODEL_TENSOR.SSM_NORM: (
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: ( MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", "model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj",

View File

@ -198,6 +198,7 @@ enum llm_arch {
LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA2,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R, LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX, LLM_ARCH_DBRX,
@ -245,6 +246,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_DBRX, "dbrx" },
@ -328,6 +330,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS, LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_MODEL,
@ -427,6 +430,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
@ -517,6 +521,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_OUT,
LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_Q_B,
@ -1068,6 +1073,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
}, },
}, },
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{ {
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
{ {
@ -2239,6 +2260,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0; uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0; uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0; uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
bool ssm_dt_b_c_rms = false; bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f; float f_clamp_kqv = 0.0f;
@ -2289,6 +2311,7 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_n_group != other.ssm_n_group) return true;
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
if (this->dec_start_token_id != other.dec_start_token_id) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true;
@ -2357,7 +2380,7 @@ struct llama_hparams {
// corresponds to Mamba's conv_states size // corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1 // TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
} }
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
@ -2419,6 +2442,7 @@ struct llama_layer {
struct ggml_tensor * ffn_sub_norm; struct ggml_tensor * ffn_sub_norm;
struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_cross;
struct ggml_tensor * attn_norm_enc; struct ggml_tensor * attn_norm_enc;
struct ggml_tensor * ssm_norm;
// attention // attention
struct ggml_tensor * wq; struct ggml_tensor * wq;
@ -5573,6 +5597,38 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_MAMBA2:
{
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24:
switch (hparams.n_embd) {
case 768: model.type = e_model::MODEL_SMALL; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 48:
switch (hparams.n_embd) {
case 1024: model.type = e_model::MODEL_MEDIUM; break;
case 1536: model.type = e_model::MODEL_LARGE; break;
case 2048: model.type = e_model::MODEL_XL; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 64:
switch (hparams.n_embd) {
case 2560: model.type = e_model::MODEL_3B; break;
case 4096: model.type = e_model::MODEL_7B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
default: model.type = e_model::MODEL_UNKNOWN;
}
}
case LLM_ARCH_XVERSE: case LLM_ARCH_XVERSE:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -6404,6 +6460,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
} }
@ -7639,7 +7696,7 @@ static bool llm_load_tensors(
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
@ -7648,9 +7705,61 @@ static bool llm_load_tensors(
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
// no "weight" suffix for these // no "weight" suffix for these
layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
// out_proj
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
}
} break;
case LLM_ARCH_MAMBA2:
{
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = n_embd / n_head;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
// only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner);
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed, duplicated to allow offloading
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
// norm
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj});
layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state});
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state});
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head});
// no "weight" suffix for these
layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {n_head});
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {n_head});
layer.ssm_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner});
// out_proj // out_proj
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
} }
@ -9041,6 +9150,8 @@ static struct ggml_tensor * llm_build_mamba(
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state; const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank; const int64_t dt_rank = hparams.ssm_dt_rank;
const int64_t n_head = d_inner;
const int64_t head_dim = 1;
const int64_t n_seqs = batch.n_seqs; const int64_t n_seqs = batch.n_seqs;
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
@ -9064,7 +9175,7 @@ static struct ggml_tensor * llm_build_mamba(
struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
graph, ssm_states_all, state_copy, state_mask, graph, ssm_states_all, state_copy, state_mask,
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
@ -9113,8 +9224,8 @@ static struct ggml_tensor * llm_build_mamba(
struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
// split // split
struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); struct ggml_tensor * B = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); struct ggml_tensor * C = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) { if (ssm_dt_b_c_rms) {
@ -9127,23 +9238,23 @@ static struct ggml_tensor * llm_build_mamba(
dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs);
// Custom operator to optimize the parallel associative scan // Custom operator to optimize the parallel associative scan
// as described in the Annex D of the Mamba paper. // as described in the Annex D of the Mamba paper.
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
// store last states // store last states
ggml_build_forward_expand(graph, ggml_build_forward_expand(graph,
ggml_cpy(ctx, ggml_cpy(ctx,
ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
// TODO: skip computing output earlier for unused tokens // TODO: skip computing output earlier for unused tokens
// {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@ -9157,6 +9268,136 @@ static struct ggml_tensor * llm_build_mamba(
return cur; return cur;
} }
static struct ggml_tensor * llm_build_mamba2(
struct ggml_context * ctx,
struct llama_context & lctx,
const llama_ubatch & batch,
struct ggml_cgraph * graph,
struct ggml_tensor * cur,
struct ggml_tensor * state_copy,
struct ggml_tensor * state_mask,
int32_t kv_head,
int32_t n_kv,
const llm_build_cb & cb,
int il) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = model.hparams;
const llama_kv_cache & kv = lctx.kv_self;
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = d_inner / n_head; // FIXME
const int64_t n_group = hparams.ssm_n_group;
const int64_t n_seqs = batch.n_seqs;
const int64_t n_seq_tokens = batch.n_seq_tokens;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(batch.equal_seqs);
GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * conv_states_all = kv.k_l[il];
struct ggml_tensor * ssm_states_all = kv.v_l[il];
// (ab)using the KV cache to store the states
struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
graph, conv_states_all, state_copy, state_mask,
hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
graph, ssm_states_all, state_copy, state_mask,
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
// split the above in three
struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0);
struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
// conv
{
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, xBC), 0);
// copy last (d_conv - 1) columns back into the state cache
struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
ggml_build_forward_expand(graph,
ggml_cpy(ctx, last_conv,
ggml_view_1d(ctx, conv_states_all,
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
// 1D convolution
// The equivalent is to make a self-overlapping view of conv_x
// over d_conv columns at each stride in the 3rd dimension,
// then element-wise multiply that with the conv1d weight,
// then sum the elements of each row,
// (the last two steps are a dot product over rows (also doable with mul_mat))
// then permute away the ne[0] dimension,
// and then you're left with the resulting x tensor.
// For simultaneous sequences, all sequences need to have the same length.
xBC = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
// bias
xBC = ggml_add(ctx, xBC, model.layers[il].ssm_conv1d_b);
xBC = ggml_silu(ctx, xBC);
}
// ssm
{
// These correspond to V K Q in SSM/attention duality
struct ggml_tensor * x = ggml_view_4d(ctx, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
struct ggml_tensor * B = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
struct ggml_tensor * C = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
// {n_head, n_seq_tokens, n_seqs}
dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d);
// store last states
ggml_build_forward_expand(graph,
ggml_cpy(ctx,
ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
// TODO: skip computing output earlier for unused tokens
y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
// grouped RMS norm
y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
y = llm_build_norm(ctx, y, hparams,
model.layers[il].ssm_norm, NULL,
LLM_NORM_RMS, cb, il);
y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
cb(cur, "mamba_out", il);
return cur;
}
struct llm_build_context { struct llm_build_context {
const llama_model & model; const llama_model & model;
llama_context & lctx; llama_context & lctx;
@ -12788,7 +13029,7 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_mamba() { struct ggml_cgraph * build_mamba(int32_t version = 1) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
struct ggml_tensor * cur; struct ggml_tensor * cur;
@ -12807,9 +13048,19 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
switch (version) {
case 2:
cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur,
state_copy, state_mask,
kv_head, n_kv, cb, il);
break;
case 1:
default:
cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
state_copy, state_mask, state_copy, state_mask,
kv_head, n_kv, cb, il); kv_head, n_kv, cb, il);
break;
}
if (il == n_layer - 1) { if (il == n_layer - 1) {
// skip computing output for unused tokens // skip computing output for unused tokens
@ -14858,7 +15109,11 @@ static struct ggml_cgraph * llama_build_graph(
} break; } break;
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
{ {
result = llm.build_mamba(); result = llm.build_mamba(/* version */ 1);
} break;
case LLM_ARCH_MAMBA2:
{
result = llm.build_mamba(/* version */ 2);
} break; } break;
case LLM_ARCH_XVERSE: case LLM_ARCH_XVERSE:
{ {
@ -17954,6 +18209,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_REFACT: case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM: case LLM_ARCH_BLOOM:
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5: case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER: case LLM_ARCH_T5ENCODER:
@ -18125,6 +18381,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
bool llama_model_is_recurrent(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) {
switch (model->arch) { switch (model->arch) {
case LLM_ARCH_MAMBA2:
case LLM_ARCH_MAMBA: return true; case LLM_ARCH_MAMBA: return true;
default: return false; default: return false;
} }