mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
Add support for BERT embedding models (#5423)
* BERT model graph construction (build_bert) * WordPiece tokenizer (llm_tokenize_wpm) * Add flag for non-causal attention models * Allow for models that only output embeddings * Support conversion of BERT models to GGUF * Based on prior work by @xyzhang626 and @skeskinen --------- Co-authored-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
97a336507e
commit
2891c8aa9a
@ -209,6 +209,8 @@ class Model:
|
||||
return InternLM2Model
|
||||
if model_architecture == "MiniCPMForCausalLM":
|
||||
return MiniCPMModel
|
||||
if model_architecture == "BertModel":
|
||||
return BertModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
@ -264,6 +266,8 @@ class Model:
|
||||
return gguf.MODEL_ARCH.INTERNLM2
|
||||
if arch == "MiniCPMForCausalLM":
|
||||
return gguf.MODEL_ARCH.MINICPM
|
||||
if arch == "BertModel":
|
||||
return gguf.MODEL_ARCH.BERT
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.")
|
||||
self.post_write_tensors(tensor_map, name, data_torch)
|
||||
|
||||
|
||||
class BertModel(Model):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.block_count = self.hparams["num_hidden_layers"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# TODO(cebtenzzre): merge with parent class
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def set_vocab(self):
|
||||
path = self.dir_model
|
||||
added_tokens_path = self.dir_model if self.dir_model.exists() else None
|
||||
|
||||
# use huggingface vocab to get all tokens
|
||||
vocab = HfVocab(path, added_tokens_path)
|
||||
tokens, scores, toktypes = zip(*vocab.all_tokens())
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# we need this to validate the size of the token_type embeddings
|
||||
# though currently we are passing all zeros to the token_type embeddings
|
||||
n_token_types = len(set(toktypes))
|
||||
self.gguf_writer.add_token_type_count(n_token_types)
|
||||
|
||||
# convert to phantom space vocab
|
||||
def phantom(tok, typ):
|
||||
if tok.startswith(b"[") and tok.endswith(b"]"):
|
||||
return tok
|
||||
if tok.startswith(b"##"):
|
||||
return tok[2:]
|
||||
return b"\xe2\x96\x81" + tok
|
||||
tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
|
||||
|
||||
# set up bos and eos tokens (cls and sep)
|
||||
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
|
||||
|
||||
# add vocab to gguf
|
||||
self.gguf_writer.add_tokenizer_model("bert")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
# handle special tokens
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def write_tensors(self):
|
||||
tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
tensors = dict(self.get_tensors())
|
||||
for name, data_torch in tensors.items():
|
||||
# we are only using BERT for embeddings so we don't need the pooling layer
|
||||
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
|
||||
continue # we don't need these
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print(f"Can not map tensor {name!r}")
|
||||
sys.exit()
|
||||
|
||||
data = data_torch.squeeze().numpy()
|
||||
n_dims = len(data.shape)
|
||||
new_dtype: type[np.floating[Any]]
|
||||
|
||||
if (
|
||||
self.ftype == 1 and name.endswith(".weight") and n_dims == 2
|
||||
and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32
|
||||
):
|
||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||
new_dtype = np.float16
|
||||
else:
|
||||
# if f32 desired, convert any float16 to float32
|
||||
new_dtype = np.float32
|
||||
|
||||
print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
|
||||
|
||||
if data.dtype != new_dtype:
|
||||
data = data.astype(new_dtype)
|
||||
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
const auto * embeddings = llama_get_embeddings(ctx);
|
||||
auto * embeddings = llama_get_embeddings(ctx);
|
||||
|
||||
// l2-normalize embeddings
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
norm += embeddings[i] * embeddings[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
embeddings[i] /= norm;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
printf("%f ", embeddings[i]);
|
||||
|
@ -50,6 +50,7 @@ class Keys:
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@ -60,22 +61,23 @@ class Keys:
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
LIST = "tokenizer.ggml.tokens"
|
||||
TOKEN_TYPE = "tokenizer.ggml.token_type"
|
||||
SCORES = "tokenizer.ggml.scores"
|
||||
MERGES = "tokenizer.ggml.merges"
|
||||
BOS_ID = "tokenizer.ggml.bos_token_id"
|
||||
EOS_ID = "tokenizer.ggml.eos_token_id"
|
||||
UNK_ID = "tokenizer.ggml.unknown_token_id"
|
||||
SEP_ID = "tokenizer.ggml.seperator_token_id"
|
||||
PAD_ID = "tokenizer.ggml.padding_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
HF_JSON = "tokenizer.huggingface.json"
|
||||
RWKV = "tokenizer.rwkv.world"
|
||||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
LIST = "tokenizer.ggml.tokens"
|
||||
TOKEN_TYPE = "tokenizer.ggml.token_type"
|
||||
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
|
||||
SCORES = "tokenizer.ggml.scores"
|
||||
MERGES = "tokenizer.ggml.merges"
|
||||
BOS_ID = "tokenizer.ggml.bos_token_id"
|
||||
EOS_ID = "tokenizer.ggml.eos_token_id"
|
||||
UNK_ID = "tokenizer.ggml.unknown_token_id"
|
||||
SEP_ID = "tokenizer.ggml.seperator_token_id"
|
||||
PAD_ID = "tokenizer.ggml.padding_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
HF_JSON = "tokenizer.huggingface.json"
|
||||
RWKV = "tokenizer.rwkv.world"
|
||||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||
|
||||
|
||||
#
|
||||
@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_OUT = auto()
|
||||
ATTN_NORM = auto()
|
||||
ATTN_NORM_2 = auto()
|
||||
ATTN_OUT_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_NORM = auto()
|
||||
@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_UP_EXP = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -178,6 +182,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
@ -187,6 +192,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -262,17 +268,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
],
|
||||
MODEL_ARCH.BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.TOKEN_TYPES,
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.MPT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
@ -357,6 +357,9 @@ class GGUFWriter:
|
||||
def add_layer_norm_rms_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
|
||||
|
||||
def add_causal_attention(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
@ -387,6 +390,9 @@ class GGUFWriter:
|
||||
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
|
||||
self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
|
||||
|
||||
def add_token_type_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
|
||||
|
||||
def add_token_scores(self, scores: Sequence[float]) -> None:
|
||||
self.add_array(Keys.Tokenizer.SCORES, scores)
|
||||
|
||||
|
@ -30,6 +30,7 @@ class TensorNameMap:
|
||||
# Normalization of token embeddings
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
||||
"word_embeddings_layernorm", # bloom
|
||||
"embeddings.LayerNorm", # bert
|
||||
),
|
||||
|
||||
# Position embeddings
|
||||
@ -54,7 +55,6 @@ class TensorNameMap:
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon
|
||||
"model.norm", # llama-hf baichuan internlm2
|
||||
"norm", # llama-pth
|
||||
"embeddings.LayerNorm", # bert
|
||||
"transformer.norm_f", # mpt
|
||||
"ln_f", # refact bloom qwen gpt2
|
||||
"language_model.encoder.final_layernorm", # persimmon
|
||||
@ -79,7 +79,6 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
"h.{bid}.ln_1", # gpt2
|
||||
@ -155,6 +154,11 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wo", # internlm2
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: (
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
),
|
||||
|
||||
# Rotary embeddings
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: (
|
||||
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
|
||||
@ -171,7 +175,6 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln2", # yi
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
@ -266,6 +269,10 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
|
||||
),
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: (
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
)
|
||||
}
|
||||
|
||||
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
||||
|
498
llama.cpp
498
llama.cpp
@ -196,6 +196,7 @@ enum llm_arch {
|
||||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_PERSIMMON,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
LLM_ARCH_QWEN,
|
||||
@ -220,6 +221,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_STARCODER, "starcoder" },
|
||||
{ LLM_ARCH_PERSIMMON, "persimmon" },
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
{ LLM_ARCH_QWEN, "qwen" },
|
||||
@ -261,6 +263,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH,
|
||||
LLM_KV_ATTENTION_LAYERNORM_EPS,
|
||||
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
@ -273,6 +276,7 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_MODEL,
|
||||
LLM_KV_TOKENIZER_LIST,
|
||||
LLM_KV_TOKENIZER_TOKEN_TYPE,
|
||||
LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
|
||||
LLM_KV_TOKENIZER_SCORES,
|
||||
LLM_KV_TOKENIZER_MERGES,
|
||||
LLM_KV_TOKENIZER_BOS_ID,
|
||||
@ -316,6 +320,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
@ -328,6 +333,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
|
||||
{ LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
|
||||
{ LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
|
||||
{ LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
|
||||
@ -355,6 +361,7 @@ struct LLM_KV {
|
||||
enum llm_tensor {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_TOKEN_EMBD_NORM,
|
||||
LLM_TENSOR_TOKEN_TYPES,
|
||||
LLM_TENSOR_POS_EMBD,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
@ -536,6 +543,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BERT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.layer_output_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BLOOM,
|
||||
{
|
||||
@ -1440,6 +1464,11 @@ static llama_state g_state;
|
||||
// available llama models
|
||||
enum e_model {
|
||||
MODEL_UNKNOWN,
|
||||
MODEL_17M,
|
||||
MODEL_22M,
|
||||
MODEL_33M,
|
||||
MODEL_109M,
|
||||
MODEL_335M,
|
||||
MODEL_0_5B,
|
||||
MODEL_1B,
|
||||
MODEL_2B,
|
||||
@ -1481,6 +1510,7 @@ struct llama_hparams {
|
||||
uint32_t n_ff;
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_vocab_type = 0; // for BERT-style token types
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
@ -1493,6 +1523,8 @@ struct llama_hparams {
|
||||
float f_clamp_kqv;
|
||||
float f_max_alibi_bias;
|
||||
|
||||
bool causal_attn = true;
|
||||
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
@ -1720,6 +1752,7 @@ struct llama_model {
|
||||
llama_vocab vocab;
|
||||
|
||||
struct ggml_tensor * tok_embd;
|
||||
struct ggml_tensor * type_embd;
|
||||
struct ggml_tensor * pos_embd;
|
||||
struct ggml_tensor * tok_norm;
|
||||
struct ggml_tensor * tok_norm_b;
|
||||
@ -1850,6 +1883,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
|
||||
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
|
||||
struct ggml_tensor * inp_sum; // F32 [1, n_batch]
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_context * ctx_mpi = NULL;
|
||||
@ -2829,6 +2863,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
|
||||
switch (type) {
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
@ -3000,6 +3035,26 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BERT:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
model.type = e_model::MODEL_17M; break; // bge-micro
|
||||
case 6:
|
||||
model.type = e_model::MODEL_22M; break; // MiniLM-L6
|
||||
case 12:
|
||||
switch (hparams.n_embd) {
|
||||
case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
|
||||
case 768: model.type = e_model::MODEL_109M; break; // bge-base
|
||||
} break;
|
||||
case 24:
|
||||
model.type = e_model::MODEL_335M; break; // bge-large
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BLOOM:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@ -3204,6 +3259,16 @@ static void llm_load_vocab(
|
||||
vocab.special_unk_id = -1;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
} else if (tokenizer_name == "bert") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
||||
|
||||
// default special tokens
|
||||
vocab.special_bos_id = 101;
|
||||
vocab.special_eos_id = 102;
|
||||
vocab.special_unk_id = 100;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
vocab.add_space_prefix = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||
@ -3232,6 +3297,8 @@ static void llm_load_vocab(
|
||||
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
||||
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
|
||||
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
|
||||
vocab.linefeed_id = vocab.special_pad_id;
|
||||
} else {
|
||||
const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
|
||||
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
@ -3569,6 +3636,7 @@ static bool llm_load_tensors(
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const int64_t n_embd_gqa = n_embd_v_gqa;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
const int64_t n_vocab_type = hparams.n_vocab_type;
|
||||
const int64_t n_ff = hparams.n_ff;
|
||||
|
||||
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
|
||||
@ -3783,11 +3851,50 @@ static bool llm_load_tensors(
|
||||
layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BLOOM:
|
||||
case LLM_ARCH_BERT:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
|
||||
model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
|
||||
model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
|
||||
model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
|
||||
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
|
||||
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
|
||||
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
|
||||
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
|
||||
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
|
||||
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
|
||||
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BLOOM:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
|
||||
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
|
||||
|
||||
// output
|
||||
{
|
||||
@ -4739,6 +4846,7 @@ struct llm_build_context {
|
||||
const int32_t n_orig_ctx;
|
||||
|
||||
const bool do_rope_shift;
|
||||
const bool causal_attn;
|
||||
|
||||
const llm_build_cb & cb;
|
||||
|
||||
@ -4782,6 +4890,7 @@ struct llm_build_context {
|
||||
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
|
||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||
do_rope_shift (worst_case || kv_self.has_shift),
|
||||
causal_attn (hparams.causal_attn),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
// all initializations should be done in init()
|
||||
@ -5625,6 +5734,100 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_bert() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// get input vectors with right size
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
|
||||
|
||||
// construct input embeddings (token, type, position)
|
||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||
// token types are hardcoded to zero ("Sentence A")
|
||||
struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
|
||||
inpL = ggml_add(ctx0, inpL, type_row0);
|
||||
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
|
||||
cb(inpL, "inp_embd", -1);
|
||||
|
||||
// embed layer norm
|
||||
inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
|
||||
cb(inpL, "inp_norm", -1);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
|
||||
|
||||
// iterate layers
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur = inpL;
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// seems like we just need to do this for Q?
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
// re-add the layer input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
// attention layer norm
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);
|
||||
|
||||
struct ggml_tensor * ffn_inp = cur;
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// attentions bypass the intermediate layer
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
// output layer norm
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// final output
|
||||
cur = inpL;
|
||||
|
||||
// pooling
|
||||
cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
|
||||
cb(cur, "result_embed", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_bloom() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
@ -7060,7 +7263,8 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
|
||||
(llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
f = 0;
|
||||
@ -7081,6 +7285,15 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
data[i] = lctx.kv_self.cells[i].delta;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
|
||||
float * data = (float *) lctx.inp_sum->data;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
data[i] = 1.0f/float(batch.n_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llm.init();
|
||||
@ -7110,6 +7323,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_refact();
|
||||
} break;
|
||||
case LLM_ARCH_BERT:
|
||||
{
|
||||
result = llm.build_bert();
|
||||
} break;
|
||||
case LLM_ARCH_BLOOM:
|
||||
{
|
||||
result = llm.build_bloom();
|
||||
@ -7269,13 +7486,18 @@ static int llama_decode_internal(
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
if (strcmp(res->name, "result_output") == 0) {
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
}
|
||||
} else if (strcmp(res->name, "result_embed") == 0) {
|
||||
embeddings = res;
|
||||
res = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
@ -7344,7 +7566,7 @@ static int llama_decode_internal(
|
||||
// extract logits
|
||||
// TODO: do not compute and extract logits if only embeddings are needed
|
||||
// need to update the graphs to skip "result_output"
|
||||
{
|
||||
if (res) {
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
#ifndef NDEBUG
|
||||
@ -7388,9 +7610,11 @@ static int llama_decode_internal(
|
||||
if (!lctx.embedding.empty()) {
|
||||
auto & embedding_out = lctx.embedding;
|
||||
|
||||
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
|
||||
|
||||
embedding_out.resize(n_embd);
|
||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_synchronize(embeddings_backend);
|
||||
}
|
||||
|
||||
@ -7454,6 +7678,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||
GGML_ASSERT(false);
|
||||
return unicode_to_bytes_bpe(token_data.text);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM: {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -7466,6 +7693,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
||||
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
||||
return vocab.token_to_id.at(buf);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
|
||||
}
|
||||
@ -7936,12 +8164,212 @@ private:
|
||||
llm_bigram_bpe::queue work_queue;
|
||||
};
|
||||
|
||||
typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
|
||||
struct llm_tokenizer_wpm {
|
||||
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
auto * token_map = &vocab.token_to_id;
|
||||
|
||||
// normalize and split by whitespace
|
||||
std::vector<std::string> words = preprocess(text);
|
||||
|
||||
// bos token prepended already
|
||||
|
||||
// find the longest tokens that form the words
|
||||
for (const std::string &word : words) {
|
||||
// skip empty words
|
||||
if (word.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// prepend phantom space
|
||||
std::string word1 = "\xe2\x96\x81" + word;
|
||||
int n = word1.size();
|
||||
|
||||
// we're at the start of a new word
|
||||
int i = 0;
|
||||
bool match_any = false;
|
||||
|
||||
// move through character position in word
|
||||
while (i < n) {
|
||||
// loop through possible match length
|
||||
bool match = false;
|
||||
for (int j = n; j > i; j--) {
|
||||
auto it = token_map->find(word1.substr(i, j - i));
|
||||
if (it != token_map->end()) {
|
||||
output.push_back(it->second);
|
||||
match = true;
|
||||
match_any = true;
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// must be an unknown character
|
||||
if (!match) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// we didn't find any matches for this word
|
||||
if (!match_any) {
|
||||
output.push_back(vocab.special_unk_id);
|
||||
}
|
||||
}
|
||||
|
||||
// append eos token
|
||||
output.push_back(vocab.special_eos_id);
|
||||
}
|
||||
|
||||
std::vector<std::string> preprocess(const std::string & text) {
|
||||
std::string ori_str = normalize(text);
|
||||
uint64_t ori_size = ori_str.size();
|
||||
|
||||
// single punct / single symbol / single digit
|
||||
// baseline: add whitespace on the left and right of punct and chinese characters
|
||||
std::vector<std::string> words;
|
||||
std::string new_str = "";
|
||||
uint64_t i = 0;
|
||||
while (i < ori_size) {
|
||||
int utf_char_len = utf8_len(ori_str[i]);
|
||||
if ((utf_char_len == 1) && ispunct(ori_str[i])) {
|
||||
new_str += " ";
|
||||
new_str += ori_str[i];
|
||||
new_str += " ";
|
||||
i += 1;
|
||||
}
|
||||
else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
|
||||
new_str += " ";
|
||||
new_str += ori_str.substr(i, 3);
|
||||
new_str += " ";
|
||||
i += 3;
|
||||
}
|
||||
else {
|
||||
new_str += ori_str[i];
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// split by whitespace
|
||||
uint64_t l = 0;
|
||||
uint64_t r = 0;
|
||||
while (r < new_str.size()) {
|
||||
// if is whitespace
|
||||
if (isspace(new_str[r])) {
|
||||
if (r > l) words.push_back(new_str.substr(l, (r - l)));
|
||||
l = r + 1;
|
||||
r = l;
|
||||
}
|
||||
else {
|
||||
r += 1;
|
||||
}
|
||||
}
|
||||
if (r > l) {
|
||||
words.push_back(new_str.substr(l, (r - l)));
|
||||
}
|
||||
return words;
|
||||
}
|
||||
|
||||
std::string normalize(const std::string & text) {
|
||||
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
|
||||
std::string text2 = strip_accents(text);
|
||||
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
|
||||
char c = text2[i];
|
||||
if (c >= 'A' && c <= 'Z') {
|
||||
text2[i] = c - 'A' + 'a';
|
||||
}
|
||||
}
|
||||
return text2;
|
||||
}
|
||||
|
||||
bool is_chinese_char(const std::string & str) {
|
||||
int len = str.length();
|
||||
unsigned int codepoint = 0;
|
||||
int num_bytes = 0;
|
||||
int i = 0;
|
||||
unsigned char ch = static_cast<unsigned char>(str[i]);
|
||||
if (ch <= 0x7f) {
|
||||
codepoint = ch;
|
||||
num_bytes = 1;
|
||||
} else if ((ch >> 5) == 0x06) {
|
||||
codepoint = ch & 0x1f;
|
||||
num_bytes = 2;
|
||||
} else if ((ch >> 4) == 0x0e) {
|
||||
codepoint = ch & 0x0f;
|
||||
num_bytes = 3;
|
||||
} else if ((ch >> 3) == 0x1e) {
|
||||
codepoint = ch & 0x07;
|
||||
num_bytes = 4;
|
||||
}
|
||||
for (int j = 1; j < num_bytes; ++j) {
|
||||
if (i + j >= len) {
|
||||
return false; // incomplete UTF-8 character
|
||||
}
|
||||
unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
|
||||
if ((next_ch >> 6) != 0x02) {
|
||||
return false; // invalid trailing byte
|
||||
}
|
||||
codepoint = (codepoint << 6) | (next_ch & 0x3f);
|
||||
}
|
||||
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
|
||||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
|
||||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
|
||||
(codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
|
||||
(codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
|
||||
(codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
||||
(codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
|
||||
(codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
|
||||
(codepoint >= 0x3000 && codepoint <= 0x303F) ||
|
||||
(codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
|
||||
return true; // NOLINT
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string strip_accents(const std::string & input_string) {
|
||||
std::string resultString;
|
||||
std::map<std::string, char> accent_map = {
|
||||
{"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
|
||||
{"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
|
||||
{"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
|
||||
{"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
|
||||
{"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
|
||||
{"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
|
||||
{"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
|
||||
{"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
|
||||
{"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < input_string.length();) {
|
||||
int len = utf8_len(input_string[i]);
|
||||
std::string curChar = input_string.substr(i, len);
|
||||
auto iter = accent_map.find(curChar);
|
||||
if (iter != accent_map.end()) {
|
||||
resultString += iter->second;
|
||||
} else {
|
||||
resultString += curChar;
|
||||
}
|
||||
i += len;
|
||||
}
|
||||
|
||||
return resultString;
|
||||
}
|
||||
|
||||
static size_t utf8_len(char src) {
|
||||
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
|
||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
const llama_vocab & vocab;
|
||||
};
|
||||
|
||||
typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
|
||||
FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
|
||||
FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
|
||||
} FRAGMENT_BUFFER_VARIANT_TYPE;
|
||||
|
||||
struct fragment_buffer_variant{
|
||||
struct fragment_buffer_variant {
|
||||
fragment_buffer_variant(llama_vocab::id _token)
|
||||
:
|
||||
type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
|
||||
@ -7971,8 +8399,7 @@ struct fragment_buffer_variant{
|
||||
|
||||
// #define PRETOKENIZERDEBUG
|
||||
|
||||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
|
||||
{
|
||||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
|
||||
// for each special token
|
||||
for (const auto & st: vocab.special_tokens_cache) {
|
||||
const auto & special_token = st.first;
|
||||
@ -8090,10 +8517,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||
switch (vocab.type) {
|
||||
case LLAMA_VOCAB_TYPE_SPM:
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer)
|
||||
{
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
// without adding this leading whitespace, we do not get the same results as the original tokenizer
|
||||
|
||||
// TODO: It's likely possible to get rid of this string copy entirely
|
||||
@ -8113,19 +8538,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||
llm_tokenizer_spm tokenizer(vocab);
|
||||
llama_escape_whitespace(raw_text);
|
||||
tokenizer.tokenize(raw_text, output);
|
||||
}
|
||||
else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
{
|
||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_BPE:
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer)
|
||||
{
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
@ -8133,9 +8554,23 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||
#endif
|
||||
llm_tokenizer_bpe tokenizer(vocab);
|
||||
tokenizer.tokenize(raw_text, output);
|
||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
{
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||
#endif
|
||||
llm_tokenizer_wpm tokenizer(vocab);
|
||||
tokenizer.tokenize(raw_text, output);
|
||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
}
|
||||
@ -10799,7 +11234,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
// graph inputs
|
||||
{
|
||||
ggml_init_params init_params = {
|
||||
/* .mem_size */ ggml_tensor_overhead()*5,
|
||||
/* .mem_size */ ggml_tensor_overhead()*7,
|
||||
/* .mem_buffer */ nullptr,
|
||||
/* .no_alloc */ true,
|
||||
};
|
||||
@ -10810,12 +11245,14 @@ struct llama_context * llama_new_context_with_model(
|
||||
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
|
||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
|
||||
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
|
||||
|
||||
ggml_set_name(ctx->inp_tokens, "inp_tokens");
|
||||
ggml_set_name(ctx->inp_embd, "inp_embd");
|
||||
ggml_set_name(ctx->inp_pos, "inp_pos");
|
||||
ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
|
||||
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
|
||||
ggml_set_name(ctx->inp_sum, "inp_sum");
|
||||
|
||||
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
|
||||
|
||||
@ -11746,6 +12183,7 @@ static std::string llama_decode_text(const std::string & text) {
|
||||
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
|
||||
if (0 <= token && token < llama_n_vocab(model)) {
|
||||
switch (llama_vocab_get_type(model->vocab)) {
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_SPM: {
|
||||
// NOTE: we accept all unsupported token types,
|
||||
// suppressing them like CONTROL tokens.
|
||||
|
Loading…
Reference in New Issue
Block a user