mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
Compare commits
3 Commits
253578003e
...
aafdf7b0c9
Author | SHA1 | Date | |
---|---|---|---|
|
aafdf7b0c9 | ||
|
09fe2e7613 | ||
|
2116f48bec |
@ -3379,6 +3379,24 @@ class CommandR2Model(Model):
|
|||||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||||
|
|
||||||
|
|
||||||
|
@Model.register("Cohere2ForCausalLM")
|
||||||
|
class Cohere2Model(Model):
|
||||||
|
model_arch = gguf.MODEL_ARCH.COHERE2
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
super().set_gguf_parameters()
|
||||||
|
|
||||||
|
self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
|
||||||
|
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||||
|
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||||
|
|
||||||
|
rotary_pct = self.hparams["rotary_pct"]
|
||||||
|
hidden_size = self.hparams["hidden_size"]
|
||||||
|
num_attention_heads = self.hparams["num_attention_heads"]
|
||||||
|
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
|
||||||
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||||
|
|
||||||
|
|
||||||
@Model.register("OlmoForCausalLM")
|
@Model.register("OlmoForCausalLM")
|
||||||
@Model.register("OLMoForCausalLM")
|
@Model.register("OLMoForCausalLM")
|
||||||
class OlmoModel(Model):
|
class OlmoModel(Model):
|
||||||
|
@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to
|
|||||||
|
|
||||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||||
|
|
||||||
|
`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
|
||||||
|
|
||||||
**Response format**
|
**Response format**
|
||||||
|
|
||||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||||
|
@ -92,6 +92,7 @@ struct slot_params {
|
|||||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
|
std::vector<std::string> response_fields;
|
||||||
bool timings_per_token = false;
|
bool timings_per_token = false;
|
||||||
bool post_sampling_probs = false;
|
bool post_sampling_probs = false;
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
@ -209,6 +210,7 @@ struct server_task {
|
|||||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||||
|
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||||
|
|
||||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||||
@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
|
|
||||||
bool post_sampling_probs;
|
bool post_sampling_probs;
|
||||||
std::vector<completion_token_output> probs_output;
|
std::vector<completion_token_output> probs_output;
|
||||||
|
std::vector<std::string> response_fields;
|
||||||
|
|
||||||
slot_params generation_params;
|
slot_params generation_params;
|
||||||
|
|
||||||
@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
if (!stream && !probs_output.empty()) {
|
if (!stream && !probs_output.empty()) {
|
||||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||||
}
|
}
|
||||||
return res;
|
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_oaicompat_chat() {
|
json to_json_oaicompat_chat() {
|
||||||
@ -2066,6 +2069,7 @@ struct server_context {
|
|||||||
res->tokens = slot.generated_tokens;
|
res->tokens = slot.generated_tokens;
|
||||||
res->timings = slot.get_timings();
|
res->timings = slot.get_timings();
|
||||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||||
|
res->response_fields = slot.params.response_fields;
|
||||||
|
|
||||||
res->truncated = slot.truncated;
|
res->truncated = slot.truncated;
|
||||||
res->n_decoded = slot.n_decoded;
|
res->n_decoded = slot.n_decoded;
|
||||||
|
@ -257,6 +257,40 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||||||
# assert match_regex(re_content, res.body["content"])
|
# assert match_regex(re_content, res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prompt,n_predict,response_fields",
|
||||||
|
[
|
||||||
|
("I believe the meaning of life is", 8, []),
|
||||||
|
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_completion_response_fields(
|
||||||
|
prompt: str, n_predict: int, response_fields: list[str]
|
||||||
|
):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request(
|
||||||
|
"POST",
|
||||||
|
"/completion",
|
||||||
|
data={
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"prompt": prompt,
|
||||||
|
"response_fields": response_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "content" in res.body
|
||||||
|
assert len(res.body["content"])
|
||||||
|
if len(response_fields):
|
||||||
|
assert res.body["generation_settings/n_predict"] == n_predict
|
||||||
|
assert res.body["prompt"] == "<s> " + prompt
|
||||||
|
assert isinstance(res.body["content"], str)
|
||||||
|
assert len(res.body) == len(response_fields)
|
||||||
|
else:
|
||||||
|
assert len(res.body)
|
||||||
|
assert "generation_settings" in res.body
|
||||||
|
|
||||||
|
|
||||||
def test_n_probs():
|
def test_n_probs():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get value by path(key1 / key2)
|
||||||
|
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
|
||||||
|
json result = json::object();
|
||||||
|
|
||||||
|
for (const std::string & path : paths) {
|
||||||
|
json current = js;
|
||||||
|
const auto keys = string_split<std::string>(path, /*separator*/ '/');
|
||||||
|
bool valid_path = true;
|
||||||
|
for (const std::string & k : keys) {
|
||||||
|
if (valid_path && current.is_object() && current.contains(k)) {
|
||||||
|
current = current[k];
|
||||||
|
} else {
|
||||||
|
valid_path = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (valid_path) {
|
||||||
|
result[path] = current;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this handles 2 cases:
|
* this handles 2 cases:
|
||||||
* - only string, example: "string"
|
* - only string, example: "string"
|
||||||
|
@ -255,6 +255,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
MAMBA = auto()
|
MAMBA = auto()
|
||||||
XVERSE = auto()
|
XVERSE = auto()
|
||||||
COMMAND_R = auto()
|
COMMAND_R = auto()
|
||||||
|
COHERE2 = auto()
|
||||||
DBRX = auto()
|
DBRX = auto()
|
||||||
OLMO = auto()
|
OLMO = auto()
|
||||||
OLMO2 = auto()
|
OLMO2 = auto()
|
||||||
@ -437,6 +438,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.MAMBA: "mamba",
|
MODEL_ARCH.MAMBA: "mamba",
|
||||||
MODEL_ARCH.XVERSE: "xverse",
|
MODEL_ARCH.XVERSE: "xverse",
|
||||||
MODEL_ARCH.COMMAND_R: "command-r",
|
MODEL_ARCH.COMMAND_R: "command-r",
|
||||||
|
MODEL_ARCH.COHERE2: "cohere2",
|
||||||
MODEL_ARCH.DBRX: "dbrx",
|
MODEL_ARCH.DBRX: "dbrx",
|
||||||
MODEL_ARCH.OLMO: "olmo",
|
MODEL_ARCH.OLMO: "olmo",
|
||||||
MODEL_ARCH.OLMO2: "olmo2",
|
MODEL_ARCH.OLMO2: "olmo2",
|
||||||
@ -1136,6 +1138,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.ATTN_K_NORM,
|
MODEL_TENSOR.ATTN_K_NORM,
|
||||||
MODEL_TENSOR.ATTN_Q_NORM,
|
MODEL_TENSOR.ATTN_Q_NORM,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.COHERE2: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
],
|
||||||
MODEL_ARCH.DBRX: [
|
MODEL_ARCH.DBRX: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
MODEL_TENSOR.OUTPUT_NORM,
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
189
src/llama.cpp
189
src/llama.cpp
@ -179,6 +179,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_MAMBA,
|
LLM_ARCH_MAMBA,
|
||||||
LLM_ARCH_XVERSE,
|
LLM_ARCH_XVERSE,
|
||||||
LLM_ARCH_COMMAND_R,
|
LLM_ARCH_COMMAND_R,
|
||||||
|
LLM_ARCH_COHERE2,
|
||||||
LLM_ARCH_DBRX,
|
LLM_ARCH_DBRX,
|
||||||
LLM_ARCH_OLMO,
|
LLM_ARCH_OLMO,
|
||||||
LLM_ARCH_OLMO2,
|
LLM_ARCH_OLMO2,
|
||||||
@ -237,6 +238,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_MAMBA, "mamba" },
|
{ LLM_ARCH_MAMBA, "mamba" },
|
||||||
{ LLM_ARCH_XVERSE, "xverse" },
|
{ LLM_ARCH_XVERSE, "xverse" },
|
||||||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||||
|
{ LLM_ARCH_COHERE2, "cohere2" },
|
||||||
{ LLM_ARCH_DBRX, "dbrx" },
|
{ LLM_ARCH_DBRX, "dbrx" },
|
||||||
{ LLM_ARCH_OLMO, "olmo" },
|
{ LLM_ARCH_OLMO, "olmo" },
|
||||||
{ LLM_ARCH_OLMO2, "olmo2" },
|
{ LLM_ARCH_OLMO2, "olmo2" },
|
||||||
@ -1268,6 +1270,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_COHERE2,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_DBRX,
|
LLM_ARCH_DBRX,
|
||||||
{
|
{
|
||||||
@ -6151,6 +6168,16 @@ static void llm_load_hparams(
|
|||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_COHERE2:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||||
|
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 32: model.type = e_model::MODEL_8B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_DBRX:
|
case LLM_ARCH_DBRX:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
@ -8970,6 +8997,32 @@ static bool llm_load_tensors(
|
|||||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_COHERE2:
|
||||||
|
{
|
||||||
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||||
|
|
||||||
|
// output
|
||||||
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||||
|
// init output from the input tok embed
|
||||||
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
|
||||||
|
llama_model_loader::TENSOR_DUPLICATED);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||||
|
|
||||||
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
|
||||||
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||||
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||||
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
|
||||||
|
|
||||||
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
|
||||||
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||||
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
|
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
|
||||||
{
|
{
|
||||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
@ -15051,6 +15104,137 @@ struct llm_build_context {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_cohere2() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
const float f_logit_scale = hparams.f_logit_scale;
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
// cohere2 requires different mask for layers using sliding window (SWA)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
|
||||||
|
|
||||||
|
// sliding window switch pattern
|
||||||
|
const int32_t sliding_window_pattern = 4;
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
// three layers sliding window attention (window size 4096) and ROPE
|
||||||
|
// fourth layer uses global attention without positional embeddings
|
||||||
|
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||||
|
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
|
||||||
|
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
struct ggml_tensor * ffn_inp = cur;
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// rope freq factors for 128k context
|
||||||
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
|
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
if (model.layers[il].bq) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
if (model.layers[il].bk) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
if (model.layers[il].bv) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_sliding) {
|
||||||
|
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
||||||
|
beta_fast, beta_slow);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||||
|
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
||||||
|
attn_factor, beta_fast, beta_slow);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
} else {
|
||||||
|
// For non-sliding layers, just reshape without applying RoPE
|
||||||
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
|
||||||
|
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1) {
|
||||||
|
// skip computing output for unused tokens
|
||||||
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||||
|
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * attn_out = cur;
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
{
|
||||||
|
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
|
||||||
|
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
|
||||||
|
cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add together residual + FFN + self-attention
|
||||||
|
cur = ggml_add(ctx0, cur, inpL);
|
||||||
|
cur = ggml_add(ctx0, cur, attn_out);
|
||||||
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||||
|
|
||||||
|
if (f_logit_scale) {
|
||||||
|
cur = ggml_scale(ctx0, cur, f_logit_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
// ref: https://allenai.org/olmo
|
// ref: https://allenai.org/olmo
|
||||||
// based on the original build_llama() function, changes:
|
// based on the original build_llama() function, changes:
|
||||||
// * non-parametric layer norm
|
// * non-parametric layer norm
|
||||||
@ -17802,6 +17986,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_command_r();
|
result = llm.build_command_r();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_COHERE2:
|
||||||
|
{
|
||||||
|
result = llm.build_cohere2();
|
||||||
|
} break;
|
||||||
case LLM_ARCH_DBRX:
|
case LLM_ARCH_DBRX:
|
||||||
{
|
{
|
||||||
result = llm.build_dbrx();
|
result = llm.build_dbrx();
|
||||||
@ -21075,6 +21263,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||||||
case LLM_ARCH_MINICPM:
|
case LLM_ARCH_MINICPM:
|
||||||
case LLM_ARCH_XVERSE:
|
case LLM_ARCH_XVERSE:
|
||||||
case LLM_ARCH_COMMAND_R:
|
case LLM_ARCH_COMMAND_R:
|
||||||
|
case LLM_ARCH_COHERE2:
|
||||||
case LLM_ARCH_OLMO:
|
case LLM_ARCH_OLMO:
|
||||||
case LLM_ARCH_ARCTIC:
|
case LLM_ARCH_ARCTIC:
|
||||||
case LLM_ARCH_DEEPSEEK:
|
case LLM_ARCH_DEEPSEEK:
|
||||||
|
Loading…
Reference in New Issue
Block a user