llama : support InternLM2 (#5184)

* support InternLM2 inference
  * add add_space_prefix KV pair
This commit is contained in:
Guoteng 2024-02-01 17:19:51 +08:00 committed by GitHub
parent 1cfb5372cf
commit ce32060198
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 387 additions and 5 deletions

View File

@ -203,6 +203,8 @@ class Model:
return CodeShellModel return CodeShellModel
if model_architecture == "OrionForCausalLM": if model_architecture == "OrionForCausalLM":
return OrionModel return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool:
@ -254,6 +256,8 @@ class Model:
return gguf.MODEL_ARCH.CODESHELL return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM": if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2
raise NotImplementedError(f'Architecture "{arch}" not supported!') raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -1344,6 +1348,154 @@ class CodeShellModel(Model):
self.gguf_writer.add_tensor("output.weight", data) self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
class InternLM2Model(Model):
def set_vocab(self):
# (TODO): Is there a better way?
# Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
# \x00 specially and convert it into an emoji character to prevent it from being mistakenly
# recognized as an empty string in C++.
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model
tokenizer_path = self.dir_model / 'tokenizer.model'
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
sys.exit(1)
sentencepiece_model = model.ModelProto()
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
for token_id in range(vocab_size):
piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)
if text == b"\x00":
# (TODO): fixme
# Hack here and replace the \x00 characters.
print(f"InternLM2 convert token '{text}' to '🐉'!")
text = "🐉"
toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)
for key in added_tokens_json:
tokens.append(key.encode("utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
def post_write_tensors(self, tensor_map, name, data_torch):
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
def write_tensors(self):
from einops import rearrange
num_heads = self.hparams.get("num_attention_heads")
num_kv_heads = self.hparams.get("num_key_value_heads")
hidden_size = self.hparams.get("hidden_size")
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
num_groups = num_heads // q_per_kv
block_count = self.hparams["num_hidden_layers"]
model_kv = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue
if re.match(qkv_pattern, name):
bid = re.findall(qkv_pattern, name)[0]
qkv = data_torch
qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
q = rearrange(q, " o g n i -> o (g n i)").T
k = rearrange(k, " o g n i -> o (g n i)").T
v = rearrange(v, " o g n i -> o (g n i)").T
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
else:
self.post_write_tensors(tensor_map, name, data_torch)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######

View File

@ -72,6 +72,7 @@ class Keys:
PAD_ID = "tokenizer.ggml.padding_token_id" PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token" ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token" ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json" HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world" RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template" CHAT_TEMPLATE = "tokenizer.chat_template"
@ -102,6 +103,7 @@ class MODEL_ARCH(IntEnum):
PLAMO = auto() PLAMO = auto()
CODESHELL = auto() CODESHELL = auto()
ORION = auto() ORION = auto()
INTERNLM2 = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
@ -153,6 +155,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion", MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
} }
TENSOR_NAMES: dict[MODEL_TENSOR, str] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -446,6 +449,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.INTERNLM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO # TODO
} }

View File

@ -411,6 +411,9 @@ class GGUFWriter:
def add_add_eos_token(self, value: bool) -> None: def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value) self.add_bool(Keys.Tokenizer.ADD_EOS, value)
def add_add_space_prefix(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
def add_chat_template(self, value: str) -> None: def add_chat_template(self, value: str) -> None:
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

View File

@ -19,6 +19,7 @@ class TensorNameMap:
"language_model.embedding.word_embeddings", # persimmon "language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2 "wte", # gpt2
"transformer.embd.wte", # phi2 "transformer.embd.wte", # phi2
"model.tok_embeddings", # internlm2
), ),
# Token type embeddings # Token type embeddings
@ -42,7 +43,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT: ( MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox "embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
"output", # llama-pth bloom "output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon "word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2 "lm_head.linear", # phi2
), ),
@ -51,7 +52,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: ( MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox "gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon "transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan "model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth "norm", # llama-pth
"embeddings.LayerNorm", # bert "embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt "transformer.norm_f", # mpt
@ -84,6 +85,7 @@ class TensorNameMap:
"h.{bid}.ln_1", # gpt2 "h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2 "transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo "model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
), ),
# Attention norm 2 # Attention norm 2
@ -111,6 +113,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.query", # bert "encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j "transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq" # internlm2
), ),
# Attention key # Attention key
@ -120,6 +123,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.key", # bert "encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk" # internlm2
), ),
# Attention value # Attention value
@ -129,6 +133,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.value", # bert "encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv" # internlm2
), ),
# Attention output # Attention output
@ -147,6 +152,7 @@ class TensorNameMap:
"h.{bid}.attn.c_proj", # gpt2 "h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2 "transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
), ),
# Rotary embeddings # Rotary embeddings
@ -169,6 +175,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi "model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2 "h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
), ),
MODEL_TENSOR.FFN_GATE_INP: ( MODEL_TENSOR.FFN_GATE_INP: (
@ -194,6 +201,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc1", # phi2 "transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.layers.{bid}.mlp.up_proj", # plamo
"model.layers.{bid}.feed_forward.w3", # internlm2
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
@ -212,6 +220,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.w1", # llama-pth "layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.w2", # qwen
"model.layers.layers.{bid}.mlp.gate_proj", # plamo "model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
@ -236,6 +245,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc2", # phi2 "transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.{bid}.mlp.fc2", # phi2 "model.layers.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo "model.layers.layers.{bid}.mlp.down_proj", # plamo
"model.layers.{bid}.feed_forward.w2", # internlm2
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (

203
llama.cpp
View File

@ -204,6 +204,7 @@ enum llm_arch {
LLM_ARCH_PLAMO, LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL, LLM_ARCH_CODESHELL,
LLM_ARCH_ORION, LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -226,6 +227,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" }, { LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
}; };
enum llm_kv { enum llm_kv {
@ -278,6 +280,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_PAD_ID, LLM_KV_TOKENIZER_PAD_ID,
LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_RWKV,
}; };
@ -332,6 +335,7 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
}; };
@ -669,7 +673,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_INTERNLM2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
{ {
@ -1377,6 +1397,7 @@ enum e_model {
MODEL_13B, MODEL_13B,
MODEL_14B, MODEL_14B,
MODEL_15B, MODEL_15B,
MODEL_20B,
MODEL_30B, MODEL_30B,
MODEL_34B, MODEL_34B,
MODEL_40B, MODEL_40B,
@ -1618,6 +1639,8 @@ struct llama_vocab {
id special_suffix_id = 32008; id special_suffix_id = 32008;
id special_eot_id = 32010; id special_eot_id = 32010;
bool add_space_prefix = true;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos); GGML_ASSERT(token_left.find(' ') == std::string::npos);
GGML_ASSERT(token_left.find('\n') == std::string::npos); GGML_ASSERT(token_left.find('\n') == std::string::npos);
@ -2731,6 +2754,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_13B: return "13B"; case MODEL_13B: return "13B";
case MODEL_14B: return "14B"; case MODEL_14B: return "14B";
case MODEL_15B: return "15B"; case MODEL_15B: return "15B";
case MODEL_20B: return "20B";
case MODEL_30B: return "30B"; case MODEL_30B: return "30B";
case MODEL_34B: return "34B"; case MODEL_34B: return "34B";
case MODEL_40B: return "40B"; case MODEL_40B: return "40B";
@ -2743,6 +2767,14 @@ static const char * llama_model_type_name(e_model type) {
default: return "?B"; default: return "?B";
} }
} }
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
switch (type) {
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
default: return "unknown";
}
}
static void llm_load_arch(llama_model_loader & ml, llama_model & model) { static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
model.arch = ml.get_arch(); model.arch = ml.get_arch();
@ -3006,6 +3038,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_INTERNLM2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
case 48: model.type = e_model::MODEL_20B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0; default: (void)0;
} }
@ -3057,6 +3098,11 @@ static void llm_load_vocab(
vocab.special_unk_id = 0; vocab.special_unk_id = 0;
vocab.special_sep_id = -1; vocab.special_sep_id = -1;
vocab.special_pad_id = -1; vocab.special_pad_id = -1;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
} // The default value of add_space_prefix is true.
} else if (tokenizer_name == "gpt2") { } else if (tokenizer_name == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE; vocab.type = LLAMA_VOCAB_TYPE_BPE;
@ -3269,7 +3315,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
// hparams // hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, llama_model_vocab_type_name(vocab.type));
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
@ -4018,8 +4064,35 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
} }
} break; } break;
case LLM_ARCH_INTERNLM2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
// layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
default: default:
throw std::runtime_error("unknown architecture"); throw std::runtime_error("unknown architecture");
} }
@ -6588,6 +6661,126 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_internlm2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
}; };
static struct ggml_cgraph * llama_build_graph( static struct ggml_cgraph * llama_build_graph(
@ -6746,6 +6939,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_orion(); result = llm.build_orion();
} break; } break;
case LLM_ARCH_INTERNLM2:
{
result = llm.build_internlm2();
} break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
} }
@ -7688,8 +7885,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
// //
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (&fragment == &fragment_buffer.front()) { if (&fragment == &fragment_buffer.front()) {
if (vocab.add_space_prefix) {
raw_text = " " + raw_text; // prefix with space if the first token is not special raw_text = " " + raw_text; // prefix with space if the first token is not special
} }
}
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());