mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
llama : support models without vocabulary (#5798)
* additional methods to read model and ctx parameters * vocab size as a part of a model metadata * models without vocabulary, convert.py part * models without vocabulary, llama.cpp part * PR clean up * converter scrypt fixes * llama_vocab_type update (renamed the new key) * pr review fixes * revert function renaming * one more NoVocab assert
This commit is contained in:
parent
044ec4b2a5
commit
69ff61397d
126
convert.py
126
convert.py
@ -332,6 +332,9 @@ class Params:
|
||||
#
|
||||
|
||||
class BpeVocab:
|
||||
tokenizer_model = "gpt2"
|
||||
name = "bpe"
|
||||
|
||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
||||
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
|
||||
if isinstance(self.bpe_tokenizer.get('model'), dict):
|
||||
@ -390,6 +393,9 @@ class BpeVocab:
|
||||
|
||||
|
||||
class SentencePieceVocab:
|
||||
tokenizer_model = "llama"
|
||||
name = "spm"
|
||||
|
||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
||||
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
|
||||
added_tokens: dict[str, int]
|
||||
@ -453,6 +459,9 @@ class SentencePieceVocab:
|
||||
|
||||
|
||||
class HfVocab:
|
||||
tokenizer_model = "llama"
|
||||
name = "hfft"
|
||||
|
||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
@ -553,7 +562,15 @@ class HfVocab:
|
||||
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab"
|
||||
class NoVocab:
|
||||
tokenizer_model = "no_vocab"
|
||||
name = "no_vocab"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<NoVocab for a model without integrated vocabulary>"
|
||||
|
||||
|
||||
Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
|
||||
|
||||
|
||||
#
|
||||
@ -935,8 +952,10 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N
|
||||
# Handle special case where the model's vocab size is not set
|
||||
if params.n_vocab == -1:
|
||||
raise ValueError(
|
||||
f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?"
|
||||
f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
|
||||
)
|
||||
if isinstance(vocab, NoVocab):
|
||||
return # model has no vocab
|
||||
|
||||
# Check for a vocab size mismatch
|
||||
if params.n_vocab == vocab.vocab_size:
|
||||
@ -977,6 +996,7 @@ class OutputFile:
|
||||
name = str(params.path_model.parent).split('/')[-1]
|
||||
|
||||
self.gguf.add_name (name)
|
||||
self.gguf.add_vocab_size (params.n_vocab)
|
||||
self.gguf.add_context_length (params.n_ctx)
|
||||
self.gguf.add_embedding_length (params.n_embd)
|
||||
self.gguf.add_block_count (params.n_layer)
|
||||
@ -1013,21 +1033,9 @@ class OutputFile:
|
||||
if params.ftype is not None:
|
||||
self.gguf.add_file_type(params.ftype)
|
||||
|
||||
def handle_tokenizer_model(self, vocab: Vocab) -> str:
|
||||
# Map the vocab types to the supported tokenizer models
|
||||
tokenizer_model = {
|
||||
SentencePieceVocab: "llama",
|
||||
HfVocab: "llama",
|
||||
BpeVocab: "gpt2",
|
||||
}.get(type(vocab))
|
||||
|
||||
# Block if vocab type is not predefined
|
||||
if tokenizer_model is None:
|
||||
raise ValueError("Unknown vocab type: Not supported")
|
||||
|
||||
return tokenizer_model
|
||||
|
||||
def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
|
||||
assert not isinstance(vocab, NoVocab)
|
||||
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
@ -1043,11 +1051,8 @@ class OutputFile:
|
||||
return tokens, scores, toktypes
|
||||
|
||||
def add_meta_vocab(self, vocab: Vocab) -> None:
|
||||
# Handle the tokenizer model
|
||||
tokenizer_model = self.handle_tokenizer_model(vocab)
|
||||
|
||||
# Ensure that tokenizer_model is added to the GGUF model
|
||||
self.gguf.add_tokenizer_model(tokenizer_model)
|
||||
self.gguf.add_tokenizer_model(vocab.tokenizer_model)
|
||||
|
||||
# Extract model vocabulary for model conversion
|
||||
tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
|
||||
@ -1074,6 +1079,26 @@ class OutputFile:
|
||||
def write_tensor_info(self) -> None:
|
||||
self.gguf.write_ti_data_to_file()
|
||||
|
||||
def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
|
||||
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
|
||||
if ftype == GGMLFileType.MostlyQ8_0:
|
||||
ndarrays = bounded_parallel_map(
|
||||
OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
|
||||
use_processpool_executor=True,
|
||||
)
|
||||
else:
|
||||
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
||||
|
||||
start = time.time()
|
||||
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
|
||||
elapsed = time.time() - start
|
||||
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
|
||||
padi = len(str(len(model)))
|
||||
print(
|
||||
f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
|
||||
)
|
||||
self.gguf.write_tensor_data(ndarray)
|
||||
|
||||
def close(self) -> None:
|
||||
self.gguf.close()
|
||||
|
||||
@ -1082,7 +1107,7 @@ class OutputFile:
|
||||
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
|
||||
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
|
||||
) -> None:
|
||||
check_vocab_size(params, vocab, pad_vocab = pad_vocab)
|
||||
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
|
||||
|
||||
of = OutputFile(fname_out, endianess=endianess)
|
||||
|
||||
@ -1120,8 +1145,11 @@ class OutputFile:
|
||||
|
||||
# meta data
|
||||
of.add_meta_arch(params)
|
||||
of.add_meta_vocab(vocab)
|
||||
of.add_meta_special_vocab(svocab)
|
||||
if isinstance(vocab, NoVocab):
|
||||
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
|
||||
else:
|
||||
of.add_meta_vocab(vocab)
|
||||
of.add_meta_special_vocab(svocab)
|
||||
|
||||
# tensor info
|
||||
for name, lazy_tensor in model.items():
|
||||
@ -1131,24 +1159,7 @@ class OutputFile:
|
||||
of.write_tensor_info()
|
||||
|
||||
# tensor data
|
||||
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
|
||||
if ftype == GGMLFileType.MostlyQ8_0:
|
||||
ndarrays = bounded_parallel_map(
|
||||
OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
|
||||
use_processpool_executor=True,
|
||||
)
|
||||
else:
|
||||
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
||||
|
||||
start = time.time()
|
||||
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
|
||||
elapsed = time.time() - start
|
||||
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
|
||||
padi = len(str(len(model)))
|
||||
print(
|
||||
f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
|
||||
)
|
||||
of.gguf.write_tensor_data(ndarray)
|
||||
of.write_tensor_data(ftype, model, concurrency)
|
||||
|
||||
of.close()
|
||||
|
||||
@ -1309,8 +1320,8 @@ class VocabFactory:
|
||||
return vtype, path
|
||||
raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
|
||||
|
||||
def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab:
|
||||
load_merges = vocabtype == "bpe"
|
||||
def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
|
||||
load_merges = vocab.name == "bpe"
|
||||
n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
|
||||
return gguf.SpecialVocab(
|
||||
model_parent_path,
|
||||
@ -1319,30 +1330,34 @@ class VocabFactory:
|
||||
n_vocab=n_vocab,
|
||||
)
|
||||
|
||||
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
|
||||
def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
|
||||
vocab_type, path = self._select_file(vocab_types)
|
||||
print(f"Loading vocab file {path!r}, type {vocab_type!r}")
|
||||
|
||||
added_tokens_path = path.parent / "added_tokens.json"
|
||||
vocab: Vocab
|
||||
if vocab_type == "bpe":
|
||||
vocab = BpeVocab(
|
||||
return BpeVocab(
|
||||
path, added_tokens_path if added_tokens_path.exists() else None
|
||||
)
|
||||
elif vocab_type == "spm":
|
||||
vocab = SentencePieceVocab(
|
||||
if vocab_type == "spm":
|
||||
return SentencePieceVocab(
|
||||
path, added_tokens_path if added_tokens_path.exists() else None
|
||||
)
|
||||
elif vocab_type == "hfft":
|
||||
vocab = HfVocab(
|
||||
if vocab_type == "hfft":
|
||||
return HfVocab(
|
||||
path.parent, added_tokens_path if added_tokens_path.exists() else None
|
||||
)
|
||||
raise ValueError(vocab_type)
|
||||
|
||||
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
|
||||
vocab: Vocab
|
||||
if len(vocab_types) == 1 and "no_vocab" in vocab_types:
|
||||
vocab = NoVocab()
|
||||
else:
|
||||
raise ValueError(vocab_type)
|
||||
vocab = self._create_vocab_by_path(vocab_types)
|
||||
# FIXME: Respect --vocab-dir?
|
||||
special_vocab = self._create_special_vocab(
|
||||
vocab,
|
||||
vocab_type,
|
||||
model_parent_path,
|
||||
)
|
||||
return vocab, special_vocab
|
||||
@ -1380,6 +1395,7 @@ def main(args_in: list[str] | None = None) -> None:
|
||||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
|
||||
parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab")
|
||||
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
|
||||
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
|
||||
parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
|
||||
@ -1392,6 +1408,10 @@ def main(args_in: list[str] | None = None) -> None:
|
||||
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
|
||||
|
||||
args = parser.parse_args(args_in)
|
||||
if args.no_vocab:
|
||||
if args.vocab_only:
|
||||
raise ValueError("no need to specify --vocab-only if using --no-vocab")
|
||||
args.vocab_type = "no_vocab"
|
||||
|
||||
if args.dump_single:
|
||||
model_plus = lazy_load_file(args.model)
|
||||
@ -1442,7 +1462,7 @@ def main(args_in: list[str] | None = None) -> None:
|
||||
print(f"Wrote {outfile}")
|
||||
return
|
||||
|
||||
if model_plus.vocab is not None and args.vocab_dir is None:
|
||||
if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
|
||||
vocab = model_plus.vocab
|
||||
|
||||
print(f"Vocab info: {vocab}")
|
||||
|
@ -32,6 +32,7 @@ class Keys:
|
||||
FILE_TYPE = "general.file_type"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
@ -752,6 +753,7 @@ KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
|
||||
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
|
||||
|
||||
# LLM
|
||||
KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE
|
||||
KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH
|
||||
KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
|
||||
KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT
|
||||
|
@ -321,6 +321,9 @@ class GGUFWriter:
|
||||
self.data_alignment = alignment
|
||||
self.add_uint32(Keys.General.ALIGNMENT, alignment)
|
||||
|
||||
def add_vocab_size(self, size: int) -> None:
|
||||
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
|
||||
|
||||
def add_context_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
|
92
llama.cpp
92
llama.cpp
@ -258,6 +258,7 @@ enum llm_kv {
|
||||
LLM_KV_GENERAL_SOURCE_URL,
|
||||
LLM_KV_GENERAL_SOURCE_HF_REPO,
|
||||
|
||||
LLM_KV_VOCAB_SIZE,
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
@ -321,6 +322,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
|
||||
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
|
||||
|
||||
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
|
||||
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
|
||||
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
|
||||
@ -3242,10 +3244,11 @@ static const char * llama_model_type_name(e_model type) {
|
||||
|
||||
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
|
||||
switch (type) {
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
default: return "unknown";
|
||||
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
@ -3277,14 +3280,14 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
|
||||
|
||||
// get hparams kv
|
||||
ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
|
||||
ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
|
||||
ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
|
||||
ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
|
||||
ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
|
||||
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
|
||||
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
|
||||
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
|
||||
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
|
||||
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
|
||||
|
||||
GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
|
||||
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
|
||||
@ -3645,30 +3648,25 @@ static void llm_load_vocab(
|
||||
|
||||
const auto kv = LLM_KV(model.arch);
|
||||
|
||||
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
|
||||
if (token_idx == -1) {
|
||||
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
|
||||
}
|
||||
|
||||
const float * scores = nullptr;
|
||||
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
|
||||
if (score_idx != -1) {
|
||||
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
|
||||
}
|
||||
|
||||
const int * toktypes = nullptr;
|
||||
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
|
||||
if (toktype_idx != -1) {
|
||||
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
|
||||
}
|
||||
|
||||
// determine vocab type
|
||||
{
|
||||
std::string tokenizer_name;
|
||||
|
||||
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
|
||||
|
||||
if (tokenizer_name == "llama") {
|
||||
if (tokenizer_name == "no_vocab") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_NONE;
|
||||
|
||||
// default special tokens
|
||||
vocab.special_bos_id = -1;
|
||||
vocab.special_eos_id = -1;
|
||||
vocab.special_unk_id = -1;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
vocab.linefeed_id = -1;
|
||||
|
||||
return;
|
||||
} else if (tokenizer_name == "llama") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||
|
||||
// default special tokens
|
||||
@ -3734,6 +3732,23 @@ static void llm_load_vocab(
|
||||
}
|
||||
}
|
||||
|
||||
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
|
||||
if (token_idx == -1) {
|
||||
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
|
||||
}
|
||||
|
||||
const float * scores = nullptr;
|
||||
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
|
||||
if (score_idx != -1) {
|
||||
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
|
||||
}
|
||||
|
||||
const int * toktypes = nullptr;
|
||||
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
|
||||
if (toktype_idx != -1) {
|
||||
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
|
||||
}
|
||||
|
||||
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
|
||||
|
||||
vocab.id_to_token.resize(n_vocab);
|
||||
@ -5023,7 +5038,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||
|
||||
llm_load_print_meta(ml, model);
|
||||
|
||||
if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
||||
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
|
||||
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
||||
throw std::runtime_error("vocab size mismatch");
|
||||
}
|
||||
|
||||
@ -9361,26 +9377,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
|
||||
}
|
||||
|
||||
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
|
||||
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
|
||||
}
|
||||
|
||||
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
|
||||
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
|
||||
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
|
||||
}
|
||||
|
||||
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
|
||||
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
|
||||
}
|
||||
|
||||
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
|
||||
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
|
||||
}
|
||||
|
||||
static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
|
||||
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
||||
const auto& token_data = vocab.id_to_token.at(id);
|
||||
switch (llama_vocab_get_type(vocab)) {
|
||||
@ -9401,6 +9423,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||
}
|
||||
|
||||
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
||||
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
|
||||
static const char * hex = "0123456789ABCDEF";
|
||||
switch (llama_vocab_get_type(vocab)) {
|
||||
case LLAMA_VOCAB_TYPE_SPM: {
|
||||
@ -10232,6 +10255,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_NONE:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return output;
|
||||
@ -13138,7 +13163,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
}
|
||||
|
||||
int32_t llama_n_vocab(const struct llama_model * model) {
|
||||
return model->vocab.id_to_token.size();
|
||||
return model->hparams.n_vocab;
|
||||
}
|
||||
|
||||
int32_t llama_n_ctx_train(const struct llama_model * model) {
|
||||
@ -13962,14 +13987,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
|
||||
}
|
||||
|
||||
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
||||
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return model->vocab.id_to_token[token].text.c_str();
|
||||
}
|
||||
|
||||
float llama_token_get_score(const struct llama_model * model, llama_token token) {
|
||||
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return model->vocab.id_to_token[token].score;
|
||||
}
|
||||
|
||||
llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
|
||||
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
|
||||
return model->vocab.id_to_token[token].type;
|
||||
}
|
||||
|
||||
|
7
llama.h
7
llama.h
@ -59,9 +59,10 @@ extern "C" {
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
Loading…
Reference in New Issue
Block a user