mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-09-22 21:16:20 +00:00
add solar pro support
solar pro introduces block skip connections where blocks are connected to other, non-sequential blocks with a scale multiple this change adds 4 new keys to store the skip connections and one new tensor to store the scalar. the scalar is implemented a 1-dimensional tensor with 2 elements dervied from the model's bskcn_tv configuration. in general, the values are (bskcn_tv, 1 - bskcn_tv)
This commit is contained in:
parent
64c6af3195
commit
c42ec2f8bb
@ -4079,6 +4079,25 @@ class ExaoneModel(Model):
|
|||||||
|
|
||||||
super().prepare_tensors()
|
super().prepare_tensors()
|
||||||
|
|
||||||
|
@Model.register("SolarForCausalLM")
|
||||||
|
class SolarModel(LlamaModel):
|
||||||
|
model_arch = gguf.MODEL_ARCH.SOLAR
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
super().set_gguf_parameters()
|
||||||
|
|
||||||
|
for i, bskcn in enumerate(self.hparams[k] for k in self.hparams.keys() if k.startswith("bskcn_") and k != 'bskcn_tv'):
|
||||||
|
# store the skip connections as a layer index where a non-zero value indicates a skip connection
|
||||||
|
# this approach simplifies lookup at inference time
|
||||||
|
self.gguf_writer.add_block_skip_connection(i, [1 if n in bskcn else 0 for n in range(self.block_count)])
|
||||||
|
|
||||||
|
def prepare_tensors(self):
|
||||||
|
if bskcn_tv := self.find_hparam(['bskcn_tv'], optional=True):
|
||||||
|
# use bskcn_tv[1] for inference since bskcn_tv[0] is for training
|
||||||
|
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.BSKCN_TV), np.array([bskcn_tv[1], 1 - bskcn_tv[1]], dtype=np.float32))
|
||||||
|
|
||||||
|
super().prepare_tensors()
|
||||||
|
|
||||||
|
|
||||||
@Model.register("GraniteForCausalLM")
|
@Model.register("GraniteForCausalLM")
|
||||||
class GraniteModel(LlamaModel):
|
class GraniteModel(LlamaModel):
|
||||||
|
@ -101,20 +101,21 @@ class Keys:
|
|||||||
EMBEDDING_SCALE = "{arch}.embedding_scale"
|
EMBEDDING_SCALE = "{arch}.embedding_scale"
|
||||||
|
|
||||||
class Attention:
|
class Attention:
|
||||||
HEAD_COUNT = "{arch}.attention.head_count"
|
HEAD_COUNT = "{arch}.attention.head_count"
|
||||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||||
KEY_LENGTH = "{arch}.attention.key_length"
|
KEY_LENGTH = "{arch}.attention.key_length"
|
||||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||||
CAUSAL = "{arch}.attention.causal"
|
CAUSAL = "{arch}.attention.causal"
|
||||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||||
SCALE = "{arch}.attention.scale"
|
SCALE = "{arch}.attention.scale"
|
||||||
|
BLOCK_SKIP_CONNECTION = "{arch}.attention.block_skip_connection.{n}"
|
||||||
|
|
||||||
class Rope:
|
class Rope:
|
||||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||||
@ -235,6 +236,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
NEMOTRON = auto()
|
NEMOTRON = auto()
|
||||||
EXAONE = auto()
|
EXAONE = auto()
|
||||||
GRANITE = auto()
|
GRANITE = auto()
|
||||||
|
SOLAR = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
@ -342,6 +344,7 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
ENC_FFN_DOWN = auto()
|
ENC_FFN_DOWN = auto()
|
||||||
ENC_FFN_UP = auto()
|
ENC_FFN_UP = auto()
|
||||||
ENC_OUTPUT_NORM = auto()
|
ENC_OUTPUT_NORM = auto()
|
||||||
|
BSKCN_TV = auto()
|
||||||
|
|
||||||
|
|
||||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
@ -392,6 +395,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.NEMOTRON: "nemotron",
|
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||||
MODEL_ARCH.EXAONE: "exaone",
|
MODEL_ARCH.EXAONE: "exaone",
|
||||||
MODEL_ARCH.GRANITE: "granite",
|
MODEL_ARCH.GRANITE: "granite",
|
||||||
|
MODEL_ARCH.SOLAR: "solar",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
@ -499,6 +503,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||||||
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
||||||
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
||||||
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
||||||
|
MODEL_TENSOR.BSKCN_TV: "bskcn_tv",
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
@ -521,6 +526,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_GATE_EXP,
|
MODEL_TENSOR.FFN_GATE_EXP,
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
MODEL_TENSOR.FFN_UP_EXP,
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
MODEL_TENSOR.BSKCN_TV,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.GROK: [
|
MODEL_ARCH.GROK: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
@ -1242,6 +1248,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.SOLAR: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.BSKCN_TV,
|
||||||
|
],
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -712,6 +712,9 @@ class GGUFWriter:
|
|||||||
def add_attention_scale(self, value: float) -> None:
|
def add_attention_scale(self, value: float) -> None:
|
||||||
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
|
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
|
||||||
|
|
||||||
|
def add_block_skip_connection(self, n: int, value: list[int]) -> None:
|
||||||
|
self.add_array(Keys.Attention.BLOCK_SKIP_CONNECTION.format(arch=self.arch, n=n), value)
|
||||||
|
|
||||||
def add_pooling_type(self, value: PoolingType) -> None:
|
def add_pooling_type(self, value: PoolingType) -> None:
|
||||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||||
|
|
||||||
|
241
src/llama.cpp
241
src/llama.cpp
@ -215,6 +215,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_EXAONE,
|
LLM_ARCH_EXAONE,
|
||||||
LLM_ARCH_RWKV6,
|
LLM_ARCH_RWKV6,
|
||||||
LLM_ARCH_GRANITE,
|
LLM_ARCH_GRANITE,
|
||||||
|
LLM_ARCH_SOLAR,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_EXAONE, "exaone" },
|
{ LLM_ARCH_EXAONE, "exaone" },
|
||||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||||
{ LLM_ARCH_GRANITE, "granite" },
|
{ LLM_ARCH_GRANITE, "granite" },
|
||||||
|
{ LLM_ARCH_SOLAR, "solar" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -322,6 +324,7 @@ enum llm_kv {
|
|||||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||||
LLM_KV_ATTENTION_SCALE,
|
LLM_KV_ATTENTION_SCALE,
|
||||||
|
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
|
||||||
|
|
||||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||||
LLM_KV_ROPE_FREQ_BASE,
|
LLM_KV_ROPE_FREQ_BASE,
|
||||||
@ -429,6 +432,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||||
|
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" },
|
||||||
|
|
||||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||||
@ -600,6 +604,7 @@ enum llm_tensor {
|
|||||||
LLM_TENSOR_ENC_FFN_DOWN,
|
LLM_TENSOR_ENC_FFN_DOWN,
|
||||||
LLM_TENSOR_ENC_FFN_UP,
|
LLM_TENSOR_ENC_FFN_UP,
|
||||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||||
|
LLM_TENSOR_BSKCN_TV,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||||
@ -1478,6 +1483,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_SOLAR,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_BSKCN_TV, "bskcn_tv" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
@ -2311,6 +2334,7 @@ enum e_model {
|
|||||||
MODEL_15B,
|
MODEL_15B,
|
||||||
MODEL_16B,
|
MODEL_16B,
|
||||||
MODEL_20B,
|
MODEL_20B,
|
||||||
|
MODEL_22B,
|
||||||
MODEL_30B,
|
MODEL_30B,
|
||||||
MODEL_34B,
|
MODEL_34B,
|
||||||
MODEL_35B,
|
MODEL_35B,
|
||||||
@ -2359,6 +2383,8 @@ struct llama_hparams {
|
|||||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
||||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||||
|
|
||||||
|
std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
|
||||||
|
|
||||||
uint32_t n_layer_dense_lead = 0;
|
uint32_t n_layer_dense_lead = 0;
|
||||||
uint32_t n_lora_q = 0;
|
uint32_t n_lora_q = 0;
|
||||||
uint32_t n_lora_kv = 0;
|
uint32_t n_lora_kv = 0;
|
||||||
@ -2429,6 +2455,7 @@ struct llama_hparams {
|
|||||||
if (this->n_head_arr != other.n_head_arr) return true;
|
if (this->n_head_arr != other.n_head_arr) return true;
|
||||||
if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
|
if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
|
||||||
if (this->n_ff_arr != other.n_ff_arr) return true;
|
if (this->n_ff_arr != other.n_ff_arr) return true;
|
||||||
|
if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
|
||||||
|
|
||||||
if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
|
if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
|
||||||
if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
|
if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
|
||||||
@ -2538,6 +2565,14 @@ struct llama_hparams {
|
|||||||
return ssm_d_state * ssm_d_inner;
|
return ssm_d_state * ssm_d_inner;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool n_bskcn(uint32_t n, uint32_t il = 0) const {
|
||||||
|
if (il < n_layer) {
|
||||||
|
return n_bskcn_arr[n][il] > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||||
@ -2719,6 +2754,8 @@ struct llama_layer {
|
|||||||
struct ggml_tensor * ffn_gate_scale;
|
struct ggml_tensor * ffn_gate_scale;
|
||||||
struct ggml_tensor * ffn_up_scale;
|
struct ggml_tensor * ffn_up_scale;
|
||||||
struct ggml_tensor * ffn_down_scale;
|
struct ggml_tensor * ffn_down_scale;
|
||||||
|
|
||||||
|
struct ggml_tensor * bskcn_tv;
|
||||||
};
|
};
|
||||||
|
|
||||||
// very similar to llama_batch,
|
// very similar to llama_batch,
|
||||||
@ -6065,6 +6102,21 @@ static void llm_load_hparams(
|
|||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_SOLAR:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
for (int i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) {
|
||||||
|
auto & bskcn = hparams.n_bskcn_arr.at(i);
|
||||||
|
bskcn.fill(0);
|
||||||
|
ml.get_key_or_arr(::format(LLM_KV_NAMES.at(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION), LLM_ARCH_NAMES.at(ml.llm_kv.arch), i), bskcn, hparams.n_layer, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 64: model.type = e_model::MODEL_22B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
}
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -8665,6 +8717,38 @@ static bool llm_load_tensors(
|
|||||||
}
|
}
|
||||||
|
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_SOLAR:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
|
||||||
|
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||||
|
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||||
|
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.bskcn_tv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_BSKCN_TV, "weight"), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
@ -15790,6 +15874,158 @@ struct llm_build_context {
|
|||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_cgraph * build_solar() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||||
|
|
||||||
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||||
|
int32_t n_tokens = this->n_tokens;
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
|
struct ggml_tensor * bskcn_1;
|
||||||
|
struct ggml_tensor * bskcn_2;
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
if (hparams.n_bskcn(0, il)) {
|
||||||
|
bskcn_1 = inpSA;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hparams.n_bskcn(1, il)) {
|
||||||
|
bskcn_2 = inpSA;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hparams.n_bskcn(2, il)) {
|
||||||
|
inpSA = ggml_add(
|
||||||
|
ctx0,
|
||||||
|
ggml_mul(ctx0, bskcn_1, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
|
||||||
|
ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hparams.n_bskcn(3, il)) {
|
||||||
|
inpSA = ggml_add(
|
||||||
|
ctx0,
|
||||||
|
ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
|
||||||
|
ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
|
||||||
|
}
|
||||||
|
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
|
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
if (model.layers[il].bq) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
if (model.layers[il].bk) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
if (model.layers[il].bv) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
Qcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1) {
|
||||||
|
// skip computing output for unused tokens
|
||||||
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
n_tokens = n_outputs;
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||||
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||||
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
|
NULL,
|
||||||
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
|
||||||
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||||
@ -16049,6 +16285,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_rwkv6();
|
result = llm.build_rwkv6();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_SOLAR:
|
||||||
|
{
|
||||||
|
result = llm.build_solar();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
@ -19173,6 +19413,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||||||
case LLM_ARCH_DEEPSEEK2:
|
case LLM_ARCH_DEEPSEEK2:
|
||||||
case LLM_ARCH_CHATGLM:
|
case LLM_ARCH_CHATGLM:
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
|
case LLM_ARCH_SOLAR:
|
||||||
return LLAMA_ROPE_TYPE_NORM;
|
return LLAMA_ROPE_TYPE_NORM;
|
||||||
|
|
||||||
// the pairs of head values are offset by n_rot/2
|
// the pairs of head values are offset by n_rot/2
|
||||||
|
Loading…
Reference in New Issue
Block a user