mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
llama : add StableLM2 12B (#6635)
* StableLM2 12B support for huggingface -> GGUF * StableLM12 tensormapping and constants * StableLM-2-12b model support * fix * Added 12B support * Removed autoformatting; resolved bug where model_arch was not selecting StableLM2 * Formatting * Do QK norm stacking in model conversion step * Converge StableLM and StableLM2 code to simplify graph construction * Fix accidental removal * Removed warnings * Revert formatter * Move QK norm stack to private function so it's easier to read * refactor stablelm graph builder to support 1.6, 3b and 12b more efficiently * Proper check for None type for new_name to avoid crash; formatting; revert change to base class `write_tensors()` * Format * Formatting * format Co-authored-by: compilade <git@compilade.net> * Fix incorrect check for K norm * space after commas; Keep indentation multiple of 4 spaces * Flake8 format * Removed unnecessary conditional branches * Removed unused comment * Fixed incorrect tensor passing * Format --------- Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
f4dea7da18
commit
dbceec87c0
@ -1207,9 +1207,91 @@ class StableLMModel(Model):
|
|||||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||||
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
|
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
|
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
|
||||||
|
|
||||||
|
def write_tensors(self):
|
||||||
|
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
||||||
|
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||||
|
n_head = self.hparams.get("num_attention_heads")
|
||||||
|
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||||
|
q_norms = dict()
|
||||||
|
k_norms = dict()
|
||||||
|
for name, data_torch in self.get_tensors():
|
||||||
|
# we don't need these
|
||||||
|
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
|
||||||
|
continue
|
||||||
|
|
||||||
|
old_dtype = data_torch.dtype
|
||||||
|
|
||||||
|
# convert any unsupported data types to float32
|
||||||
|
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||||
|
data_torch = data_torch.to(torch.float32)
|
||||||
|
|
||||||
|
data = data_torch.squeeze().numpy()
|
||||||
|
n_dims = len(data.shape)
|
||||||
|
if name.find("q_layernorm.norms") != -1:
|
||||||
|
q_norms[name] = data
|
||||||
|
if len(q_norms) >= (block_count * n_head):
|
||||||
|
self._stack_qk_norm(block_count, name, tensor_map, n_head, q_norms, n_dims, layer_name="q_layernorm")
|
||||||
|
continue
|
||||||
|
if name.find("k_layernorm.norms") != -1:
|
||||||
|
k_norms[name] = data
|
||||||
|
if len(k_norms) >= (block_count * n_kv_head):
|
||||||
|
self._stack_qk_norm(block_count, name, tensor_map, n_kv_head, k_norms, n_dims, layer_name="k_layernorm")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# map tensor names
|
||||||
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
|
if new_name is None:
|
||||||
|
print(f"Can not map tensor {name!r}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
n_dims = len(data.shape)
|
||||||
|
data_dtype = data.dtype
|
||||||
|
|
||||||
|
# if f32 desired, convert any float16 to float32
|
||||||
|
if self.ftype == 0 and data_dtype == np.float16:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||||
|
if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||||
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
def _stack_qk_norm(self, block_count, name, tensor_map, n_head, norms, n_dims, layer_name="q_layernorm"):
|
||||||
|
for bid in range(block_count):
|
||||||
|
datas = []
|
||||||
|
for xid in range(n_head):
|
||||||
|
ename = f"model.layers.{bid}.self_attn.{layer_name}.norms.{xid}.weight"
|
||||||
|
datas.append(norms[ename])
|
||||||
|
del norms[ename]
|
||||||
|
data = np.stack(datas, axis=0)
|
||||||
|
data_dtype = data.dtype
|
||||||
|
merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight"
|
||||||
|
new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
|
||||||
|
if new_name is None:
|
||||||
|
print(f"Can not map tensor {name!r}")
|
||||||
|
sys.exit()
|
||||||
|
if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||||
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
|
print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
|
||||||
|
|
||||||
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
||||||
class LlamaModel(Model):
|
class LlamaModel(Model):
|
||||||
|
@ -455,6 +455,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_GATE,
|
MODEL_TENSOR.FFN_GATE,
|
||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.ATTN_Q_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_K_NORM,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.QWEN: [
|
MODEL_ARCH.QWEN: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
62
llama.cpp
62
llama.cpp
@ -716,6 +716,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1744,6 +1746,7 @@ enum e_model {
|
|||||||
MODEL_4B,
|
MODEL_4B,
|
||||||
MODEL_7B,
|
MODEL_7B,
|
||||||
MODEL_8B,
|
MODEL_8B,
|
||||||
|
MODEL_12B,
|
||||||
MODEL_13B,
|
MODEL_13B,
|
||||||
MODEL_14B,
|
MODEL_14B,
|
||||||
MODEL_15B,
|
MODEL_15B,
|
||||||
@ -3607,6 +3610,7 @@ static const char * llama_model_type_name(e_model type) {
|
|||||||
case MODEL_3B: return "3B";
|
case MODEL_3B: return "3B";
|
||||||
case MODEL_7B: return "7B";
|
case MODEL_7B: return "7B";
|
||||||
case MODEL_8B: return "8B";
|
case MODEL_8B: return "8B";
|
||||||
|
case MODEL_12B: return "12B";
|
||||||
case MODEL_13B: return "13B";
|
case MODEL_13B: return "13B";
|
||||||
case MODEL_14B: return "14B";
|
case MODEL_14B: return "14B";
|
||||||
case MODEL_15B: return "15B";
|
case MODEL_15B: return "15B";
|
||||||
@ -3898,6 +3902,7 @@ static void llm_load_hparams(
|
|||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
case 24: model.type = e_model::MODEL_1B; break;
|
case 24: model.type = e_model::MODEL_1B; break;
|
||||||
case 32: model.type = e_model::MODEL_3B; break;
|
case 32: model.type = e_model::MODEL_3B; break;
|
||||||
|
case 40: model.type = e_model::MODEL_12B; break;
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
@ -5128,8 +5133,13 @@ static bool llm_load_tensors(
|
|||||||
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false);
|
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false);
|
||||||
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false);
|
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false);
|
||||||
|
|
||||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
// optional q and k layernorms, present in StableLM 2 12B
|
||||||
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
|
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false);
|
||||||
|
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false);
|
||||||
|
|
||||||
|
// optional FFN norm, not present in StableLM 2 12B which uses parallel residual
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false);
|
||||||
|
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false);
|
||||||
|
|
||||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||||
@ -8197,7 +8207,7 @@ struct llm_build_context {
|
|||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
@ -8206,6 +8216,8 @@ struct llm_build_context {
|
|||||||
LLM_NORM, cb, il);
|
LLM_NORM, cb, il);
|
||||||
cb(cur, "attn_norm", il);
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * inpSA = cur;
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
@ -8230,15 +8242,36 @@ struct llm_build_context {
|
|||||||
cb(Vcur, "Vcur", il);
|
cb(Vcur, "Vcur", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
if (model.layers[il].attn_q_norm) {
|
||||||
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
||||||
|
model.layers[il].attn_q_norm,
|
||||||
|
NULL,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
if (model.layers[il].attn_k_norm) {
|
||||||
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
||||||
|
model.layers[il].attn_k_norm,
|
||||||
|
NULL,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Qcur = ggml_rope_custom(
|
Qcur = ggml_rope_custom(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
ctx0, Qcur, inp_pos,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_custom(
|
Kcur = ggml_rope_custom(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
ctx0, Kcur, inp_pos,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
@ -8253,20 +8286,25 @@ struct llm_build_context {
|
|||||||
// skip computing output for unused tokens
|
// skip computing output for unused tokens
|
||||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||||
cb(ffn_inp, "ffn_inp", il);
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
// feed-forward network
|
// feed-forward network
|
||||||
{
|
{
|
||||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
if (model.layers[il].ffn_norm) {
|
||||||
model.layers[il].ffn_norm,
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
model.layers[il].ffn_norm_b,
|
model.layers[il].ffn_norm,
|
||||||
LLM_NORM, cb, il);
|
model.layers[il].ffn_norm_b,
|
||||||
cb(cur, "ffn_norm", il);
|
LLM_NORM, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
} else {
|
||||||
|
// parallel residual
|
||||||
|
cur = inpSA;
|
||||||
|
}
|
||||||
cur = llm_build_ffn(ctx0, cur,
|
cur = llm_build_ffn(ctx0, cur,
|
||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
|
Loading…
Reference in New Issue
Block a user