mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
stablelm : StableLM support (#3586)
* Add support for stablelm-3b-4e1t * Supports GPU offloading of (n-1) layers
This commit is contained in:
parent
b46d12f86d
commit
36eed0c42c
@ -93,6 +93,7 @@ as the main playground for developing new features for the [ggml](https://github
|
|||||||
- [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
|
- [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
|
||||||
- [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
|
- [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
|
||||||
- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
|
- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
|
||||||
|
- [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)
|
||||||
|
|
||||||
|
|
||||||
**Bindings:**
|
**Bindings:**
|
||||||
|
@ -150,8 +150,6 @@ class Model:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_architecture(model_architecture):
|
def from_model_architecture(model_architecture):
|
||||||
if model_architecture == "StableLMEpochForCausalLM":
|
|
||||||
return StableLMModel
|
|
||||||
if model_architecture == "GPTNeoXForCausalLM":
|
if model_architecture == "GPTNeoXForCausalLM":
|
||||||
return GPTNeoXModel
|
return GPTNeoXModel
|
||||||
if model_architecture == "BloomForCausalLM":
|
if model_architecture == "BloomForCausalLM":
|
||||||
@ -168,6 +166,8 @@ class Model:
|
|||||||
return RefactModel
|
return RefactModel
|
||||||
if model_architecture == "PersimmonForCausalLM":
|
if model_architecture == "PersimmonForCausalLM":
|
||||||
return PersimmonModel
|
return PersimmonModel
|
||||||
|
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||||
|
return StableLMModel
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
@ -201,6 +201,8 @@ class Model:
|
|||||||
return gguf.MODEL_ARCH.REFACT
|
return gguf.MODEL_ARCH.REFACT
|
||||||
if arch == "PersimmonForCausalLM":
|
if arch == "PersimmonForCausalLM":
|
||||||
return gguf.MODEL_ARCH.PERSIMMON
|
return gguf.MODEL_ARCH.PERSIMMON
|
||||||
|
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||||
|
return gguf.MODEL_ARCH.STABLELM
|
||||||
|
|
||||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||||
|
|
||||||
@ -294,15 +296,6 @@ class Model:
|
|||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
|
|
||||||
class StableLMModel(Model):
|
|
||||||
def set_gguf_parameters(self):
|
|
||||||
super().set_gguf_parameters()
|
|
||||||
self.gguf_writer.add_rope_dimension_count(
|
|
||||||
int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
|
||||||
)
|
|
||||||
self.gguf_writer.add_layer_norm_eps(1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXModel(Model):
|
class GPTNeoXModel(Model):
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
block_count = self.hparams["num_hidden_layers"]
|
||||||
@ -824,6 +817,21 @@ class PersimmonModel(Model):
|
|||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
|
class StableLMModel(Model):
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
hparams = self.hparams
|
||||||
|
block_count = hparams["num_hidden_layers"]
|
||||||
|
|
||||||
|
self.gguf_writer.add_name(dir_model.name)
|
||||||
|
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||||
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
|
self.gguf_writer.add_block_count(block_count)
|
||||||
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
|
self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"]*(hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||||
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
|
||||||
|
self.gguf_writer.add_layer_norm_eps(1e-5)
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
@ -90,6 +90,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
REFACT = auto()
|
REFACT = auto()
|
||||||
BERT = auto()
|
BERT = auto()
|
||||||
BLOOM = auto()
|
BLOOM = auto()
|
||||||
|
STABLELM = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
@ -129,6 +130,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.REFACT: "refact",
|
MODEL_ARCH.REFACT: "refact",
|
||||||
MODEL_ARCH.BERT: "bert",
|
MODEL_ARCH.BERT: "bert",
|
||||||
MODEL_ARCH.BLOOM: "bloom",
|
MODEL_ARCH.BLOOM: "bloom",
|
||||||
|
MODEL_ARCH.STABLELM: "stablelm",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
@ -299,6 +301,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.STABLELM: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
],
|
||||||
MODEL_ARCH.GPT2: [
|
MODEL_ARCH.GPT2: [
|
||||||
# TODO
|
# TODO
|
||||||
],
|
],
|
||||||
|
284
llama.cpp
284
llama.cpp
@ -192,6 +192,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_PERSIMMON,
|
LLM_ARCH_PERSIMMON,
|
||||||
LLM_ARCH_REFACT,
|
LLM_ARCH_REFACT,
|
||||||
LLM_ARCH_BLOOM,
|
LLM_ARCH_BLOOM,
|
||||||
|
LLM_ARCH_STABLELM,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -207,6 +208,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_PERSIMMON, "persimmon" },
|
{ LLM_ARCH_PERSIMMON, "persimmon" },
|
||||||
{ LLM_ARCH_REFACT, "refact" },
|
{ LLM_ARCH_REFACT, "refact" },
|
||||||
{ LLM_ARCH_BLOOM, "bloom" },
|
{ LLM_ARCH_BLOOM, "bloom" },
|
||||||
|
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llm_kv {
|
enum llm_kv {
|
||||||
@ -495,6 +497,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_STABLELM,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
@ -2216,6 +2237,16 @@ static void llm_load_hparams(
|
|||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_STABLELM:
|
||||||
|
{
|
||||||
|
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 32: model.type = e_model::MODEL_3B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3087,6 +3118,81 @@ static void llm_load_tensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_STABLELM:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
ggml_backend_type backend_norm;
|
||||||
|
ggml_backend_type backend_output;
|
||||||
|
|
||||||
|
if (n_gpu_layers > int(n_layer)) {
|
||||||
|
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
||||||
|
// on Windows however this is detrimental unless everything is on the GPU
|
||||||
|
#ifndef _WIN32
|
||||||
|
backend_norm = llama_backend_offload;
|
||||||
|
#else
|
||||||
|
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
|
backend_output = llama_backend_offload_split;
|
||||||
|
} else {
|
||||||
|
backend_norm = GGML_BACKEND_CPU;
|
||||||
|
backend_output = GGML_BACKEND_CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
|
||||||
|
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||||
|
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||||
|
|
||||||
|
if (backend_norm == GGML_BACKEND_GPU) {
|
||||||
|
vram_weights += ggml_nbytes(model.output_norm);
|
||||||
|
}
|
||||||
|
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
||||||
|
vram_weights += ggml_nbytes(model.output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_ff = hparams.n_ff;
|
||||||
|
|
||||||
|
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||||
|
|
||||||
|
model.layers.resize(n_layer);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
|
/*
|
||||||
|
llama_model_loader: - tensor 4: blk.0.attn_output.weight f16 [ 2560, 2560, 1, 1 ]
|
||||||
|
*/
|
||||||
|
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
|
||||||
|
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||||
|
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||||
|
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||||
|
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
||||||
|
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
|
||||||
|
if (backend == GGML_BACKEND_GPU) {
|
||||||
|
vram_weights +=
|
||||||
|
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||||
|
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
|
||||||
|
ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
@ -4565,6 +4671,177 @@ struct llm_build_context {
|
|||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_stablelm() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
|
||||||
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
|
// KQ_scale
|
||||||
|
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
cb(KQ_scale, "KQ_scale", -1);
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
|
// shift the entire K-cache if needed
|
||||||
|
if (do_rope_shift) {
|
||||||
|
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm,
|
||||||
|
model.layers[il].attn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(tmpq, "tmpq", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(tmpk, "tmpk", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
// RoPE the first n_rot of q/k, pass the other half, and concat.
|
||||||
|
struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
|
||||||
|
ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
|
||||||
|
ggml_element_size(tmpq) * n_embd_head,
|
||||||
|
ggml_element_size(tmpq) * n_embd_head * n_head,
|
||||||
|
0
|
||||||
|
));
|
||||||
|
cb(qrot, "qrot", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
|
||||||
|
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
|
||||||
|
ggml_element_size(tmpk) * n_embd_head,
|
||||||
|
ggml_element_size(tmpk) * n_embd_head * n_head_kv,
|
||||||
|
0
|
||||||
|
));
|
||||||
|
cb(krot, "krot", il);
|
||||||
|
|
||||||
|
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
|
||||||
|
struct ggml_tensor * qpass = ggml_view_3d(
|
||||||
|
ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
|
||||||
|
ggml_element_size(tmpq) * n_embd_head,
|
||||||
|
ggml_element_size(tmpq) * n_embd_head * n_head,
|
||||||
|
ggml_element_size(tmpq) * hparams.n_rot
|
||||||
|
);
|
||||||
|
cb(qpass, "qpass", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * kpass = ggml_view_3d(
|
||||||
|
ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
|
||||||
|
ggml_element_size(tmpk) * (n_embd_head),
|
||||||
|
ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
|
||||||
|
ggml_element_size(tmpk) * hparams.n_rot
|
||||||
|
);
|
||||||
|
cb(kpass, "kpass", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * qrotated = ggml_rope_custom(
|
||||||
|
ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
|
||||||
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(qrotated, "qrotated", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * krotated = ggml_rope_custom(
|
||||||
|
ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
|
||||||
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(krotated, "krotated", il);
|
||||||
|
|
||||||
|
// ggml currently only supports concatenation on dim=2
|
||||||
|
// so we need to permute qrot, qpass, concat, then permute back.
|
||||||
|
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
|
||||||
|
cb(qrotated, "qrotated", il);
|
||||||
|
|
||||||
|
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
|
||||||
|
cb(krotated, "krotated", il);
|
||||||
|
|
||||||
|
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
|
||||||
|
cb(qpass, "qpass", il);
|
||||||
|
|
||||||
|
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
|
||||||
|
cb(kpass, "kpass", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
|
||||||
|
cb(Q, "Q", il);
|
||||||
|
|
||||||
|
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
|
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||||
|
model.layers[il].wo, NULL,
|
||||||
|
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
{
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm,
|
||||||
|
model.layers[il].ffn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
cur = llm_build_ffn(ctx0, cur,
|
||||||
|
model.layers[il].ffn_up, NULL,
|
||||||
|
model.layers[il].ffn_gate, NULL,
|
||||||
|
model.layers[il].ffn_down, NULL,
|
||||||
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm,
|
||||||
|
model.output_norm_b,
|
||||||
|
LLM_NORM, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -5034,6 +5311,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_mpt();
|
result = llm.build_mpt();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_STABLELM:
|
||||||
|
{
|
||||||
|
result = llm.build_stablelm();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
@ -5209,7 +5490,8 @@ static int llama_decode_internal(
|
|||||||
model.arch == LLM_ARCH_FALCON ||
|
model.arch == LLM_ARCH_FALCON ||
|
||||||
model.arch == LLM_ARCH_REFACT ||
|
model.arch == LLM_ARCH_REFACT ||
|
||||||
model.arch == LLM_ARCH_MPT ||
|
model.arch == LLM_ARCH_MPT ||
|
||||||
model.arch == LLM_ARCH_STARCODER;
|
model.arch == LLM_ARCH_STARCODER ||
|
||||||
|
model.arch == LLM_ARCH_STABLELM;
|
||||||
|
|
||||||
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
|
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
|
||||||
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
|
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
|
||||||
|
BIN
models/ggml-vocab-stablelm-3b-4e1t.gguf
Normal file
BIN
models/ggml-vocab-stablelm-3b-4e1t.gguf
Normal file
Binary file not shown.
@ -33,9 +33,11 @@ llama_build_executable(test-tokenizer-1-bpe.cpp)
|
|||||||
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||||
llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
|
llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
|
||||||
llama_test_executable(test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
|
llama_test_executable(test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
|
||||||
|
llama_test_executable(test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
|
||||||
llama_test_executable(test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
|
llama_test_executable(test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
|
||||||
llama_test_executable(test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
llama_test_executable(test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
||||||
llama_test_executable(test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
|
llama_test_executable(test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
|
||||||
|
# llama_test_executable(test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
|
||||||
llama_build_and_test_executable(test-grammar-parser.cpp)
|
llama_build_and_test_executable(test-grammar-parser.cpp)
|
||||||
llama_build_and_test_executable(test-llama-grammar.cpp)
|
llama_build_and_test_executable(test-llama-grammar.cpp)
|
||||||
llama_build_and_test_executable(test-grad0.cpp) # SLOW
|
llama_build_and_test_executable(test-grad0.cpp) # SLOW
|
||||||
|
Loading…
Reference in New Issue
Block a user