mirror of https://github.com/ggerganov/llama.cpp.git
llama : tokenizer fixes (#2549)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies
This commit is contained in:
parent
8af3a99ff1
commit
ec1b100720
convert.py (103 lines changed)
@@ -238,22 +238,58 @@ class Params:
return params


class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
self.vocabtype = vocabtype
if self.vocabtype == "bpe":
self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
else:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
class BpeVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens))
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
if self.vocabtype == "bpe":
vocab_size: int = len(self.sentencepiece_tokenizer)
vocab_size: int = len(self.bpe_tokenizer)
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_list = [text for (text, idx) in items]
self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens

def bpe_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.bpe_tokenizer
from transformers.models.gpt2 import tokenization_gpt2
byte_encoder = tokenization_gpt2.bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i, item in enumerate(tokenizer):
text: bytes = item.encode("utf-8")
score: float = -i
yield text, score

def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score

def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
yield from self.bpe_tokens()
yield from self.added_tokens()

def __repr__(self) -> str:
return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:

@@ -267,32 +303,11 @@ class SentencePieceVocab:

def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.sentencepiece_tokenizer
if self.vocabtype == "bpe":
from transformers.models.gpt2 import tokenization_gpt2
byte_encoder = tokenization_gpt2.bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i, item in enumerate(tokenizer):
text: bytes
text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
score: float = -i
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
text: bytes = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
yield text, score
else:
for i in range(tokenizer.vocab_size()):
text: bytes
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
piece = tokenizer.id_to_piece(i)
if len(piece) != 6:
raise Exception(f"Invalid token: {piece}")
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
score: float = tokenizer.get_score(i)
yield text, score

def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:

@@ -319,7 +334,7 @@ class GGMLVocab:
return f"<GGMLVocab with {self.vocab_size} tokens>"


Vocab = Union[SentencePieceVocab, GGMLVocab]
Vocab = Union[BpeVocab, SentencePieceVocab, GGMLVocab]


def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:

@@ -1044,7 +1059,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
def check_vocab_size(params: Params, vocab: Vocab) -> None:
if params.n_vocab != vocab.vocab_size:
# GGMLVocab comes from the same file as the model so shouldn't mismatch:
assert isinstance(vocab, SentencePieceVocab)
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
if params.n_vocab == vocab.vocab_size_base:
print("Ignoring added_tokens.json since model matches vocab size without it.")
vocab.added_tokens_list = []

@@ -1093,7 +1108,7 @@ class OutputFile:
@staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
of = OutputFile(fname_out)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, n_kv_head=None)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)

@@ -1228,7 +1243,7 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}


def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
print(f"vocabtype: {vocabtype}")
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might

@@ -1250,8 +1265,12 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
"if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
vocabtype)
if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm":
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
else:
raise ValueError(f"Unsupported vocabulary type {vocabtype}")


def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:

@@ -633,17 +633,6 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
return "The";
}

// TODO: not great allocating this every time
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int) add_bos);
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);

return res;
}

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();

@@ -2,6 +2,7 @@

#pragma once

#define LLAMA_API_CPP // TODO: eliminate me
#include "llama.h"

#include <string>

@@ -100,12 +101,6 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params);

std::string gpt_random_prompt(std::mt19937 & rng);

//
// Vocab utils
//

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);

//
// Model utils
//

@@ -67,7 +67,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
}
fprintf(stderr, "\n");
}

@@ -191,10 +191,6 @@ int main(int argc, char ** argv) {

// tokenize the prompt
std::vector<llama_token> embd_inp;

// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {

@@ -278,7 +274,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
}

if (ctx_guidance) {

@@ -286,14 +282,14 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
}
}

if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
}
fprintf(stderr, "'\n");
}

@@ -662,7 +658,7 @@ int main(int argc, char ** argv) {
// display text
if (input_echo) {
for (auto id : embd) {
printf("%s", llama_token_to_str(ctx, id));
printf("%s", llama_token_to_str(ctx, id).c_str());
}
fflush(stdout);
}

@@ -1,6 +1,7 @@
#include "ggml.h"
#include "build-info.h"

#define LLAMA_API_CPP // TODO: eliminate me
#define LLAMA_API_INTERNAL
#include "llama.h"

@@ -45,9 +45,8 @@ int main(int argc, char ** argv) {
llama_free_model(model);
return 1;
}
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);

auto tokens = llama_tokenize(ctx, params.prompt.c_str(), true);
auto n_prompt_tokens = tokens.size();
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
llama_free(ctx);

@@ -92,7 +91,7 @@ int main(int argc, char ** argv) {
auto next_token_str = llama_token_to_str(ctx, next_token);
last_n_tokens_data.push_back(next_token);

printf("%s", next_token_str);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);

@@ -152,7 +151,7 @@ int main(int argc, char ** argv) {
auto next_token_str = llama_token_to_str(ctx2, next_token);
last_n_tokens_data.push_back(next_token);

printf("%s", next_token_str);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);

@@ -62,7 +62,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");

for (auto id : tokens_list) {
fprintf(stderr, "%s", llama_token_to_str(ctx, id));
fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str());
}

fflush(stderr);

@@ -109,7 +109,7 @@ int main(int argc, char ** argv) {
}

// print the new token :
printf("%s", llama_token_to_str(ctx, new_token_id));
printf("%s", llama_token_to_str(ctx, new_token_id).c_str());
fflush(stdout);

// push this new token for next evaluation

@@ -1,4 +1,5 @@
#include "ggml.h"
#include "common.h"
#include "llama.h"
#include <unordered_map>
#include <vector>

@@ -1961,7 +1962,7 @@ void print_matrix(struct ggml_tensor * probs) {


void print_token(struct llama_context * ctx, llama_token token) {
printf("%s", llama_token_to_str(ctx, token));
printf("%s", llama_token_to_str(ctx, token).c_str());
}

void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {

@@ -2188,11 +2189,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
f.read_raw(buf.data(), f.size);
buf[f.size] = '\0';

out.resize(buf.size());

int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false);
if (n_tokens >= 0) {
out.resize(n_tokens);
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
if (n_tokens < 0) {
out.resize(-n_tokens);
llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
}

bool verify = false;

@@ -2200,17 +2200,17 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
const char * in = buf.data();
const char * end = buf.data() + buf.size();
for (int i = 0; i < (int) out.size(); ++i) {
const char * s = llama_token_to_str(lctx, out[i]);
int len = strlen(s);
std::string s = llama_token_to_str(lctx, out[i]);
int len = s.length();
if (in >= end) {
printf("%s: unexpected end of original text.\n", __func__);
break;
}
const bool matches = (strncmp(in, s, len) == 0);
const bool matches = (strncmp(in, s.c_str(), len) == 0);
if (matches) {
in += len;
} else {
printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s);
printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str());
}
}
}
llama.cpp (326 lines changed)
@@ -7,6 +7,7 @@
#endif

#include "llama-util.h"
#define LLAMA_API_CPP // TODO: eliminate me
#include "llama.h"

#include "ggml.h"

@@ -575,6 +576,7 @@ struct llama_file_loader {
float score = 0.0f;
file.read_raw(&score, sizeof(score));

GGML_ASSERT(vocab.token_to_id.find(word) == vocab.token_to_id.end());
vocab.token_to_id[word] = i;

auto & tok_score = vocab.id_to_token[i];

@@ -1060,6 +1062,11 @@ static void llama_model_load_internal(
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));

vocab = std::move(ml->file_loader->vocab);

if (vocab_only) {
return;
}

model.hparams = ml->file_loader->hparams;
model.n_gpu_layers = n_gpu_layers;
llama_file_version file_version = ml->file_loader->file_version;

@@ -1141,10 +1148,6 @@ static void llama_model_load_internal(
}
}

if (vocab_only) {
return;
}

auto & ctx = model.ctx;

size_t ctx_size;

@@ -1940,6 +1943,105 @@ static bool llama_eval_internal(
// tokenizer
//

static std::string llama_vocab_type(const llama_vocab& vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
}

static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return token >= 259;
else if(llama_vocab_type(vocab) == "bpe")
return token >= 95;
else
return false;
}

static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return token == 0;
else
// TODO: improve?
return false;
}

static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return token == 1 || token == 2;
else
// TODO: improve?
return false;
}

static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return token == 1;
else
// TODO: improve?
return false;
}

static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return token == 2;
else
// TODO: improve?
return false;
}

static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) {
// TODO: improve?
return false;
}

static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) {
// TODO: improve?
return false;
}

static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm")
return 3 <= token && token < 259;
else if(llama_vocab_type(vocab) == "bpe")
return 1 <= token && token < 95;
else
return false;
}

static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) {
if(llama_vocab_type(vocab) == "spm")
return byte + 3;
else if(llama_vocab_type(vocab) == "bpe")
return byte + 32;
else
return false;
}

static std::string llama_escape_whitespace(const std::string& text) {
std::string result;
bool escaping = false;
result += "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
if (!escaping) {
result += "\xe2\x96\x81";
escaping = true;
}
}
else {
escaping = false;
result += text[offs];
}
}
return result;
}

static std::string llama_unescape_whitespace(const std::string& word) {
if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
return std::string(" ") + word.substr(3);
}
return word;
}

static size_t utf8_len(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(src) >> 4;

@@ -1981,10 +2083,11 @@ struct llama_tokenizer {
size_t offs = 0;
while (offs < text.size()) {
llama_sp_symbol sym;
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
size_t len = utf8_len(text[offs]);
GGML_ASSERT(offs + len <= text.size());
sym.text = text.c_str() + offs;
sym.n = char_len;
offs += char_len;
sym.n = len;
offs += len;
sym.prev = index - 1;
sym.next = offs == text.size() ? -1 : index + 1;
index++;

@@ -2029,23 +2132,36 @@ struct llama_tokenizer {

for (int i = 0; i != -1; i = symbols_[i].next) {
auto & symbol = symbols_[i];
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));

if (token == vocab_.token_to_id.end()) {
// output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int) symbol.n; ++j) {
// NOTE: old version, before #2420 - not sure what are the implications of this
//llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j]));
output.push_back(token_id);
}
} else {
output.push_back((*token).second);
}
resegment(symbol, output);
}
}

private:
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
auto text = std::string(symbol.text, symbol.n);
auto token = vocab_.token_to_id.find(text);

// Do we need to support is_unused?
if (token != vocab_.token_to_id.end()) {
output.push_back((*token).second);
return;
}

const auto p = rev_merge.find(text);

if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int)symbol.n; ++j) {
llama_vocab::id token_id = llama_byte_to_char(vocab_, symbol.text[j]);
output.push_back(token_id);
}
return;
}

resegment(symbols_[p->second.first], output);
resegment(symbols_[p->second.second], output);
}

void try_add_bigram(int left, int right) {
if (left == -1 || right == -1) {
return;

@@ -2070,18 +2186,22 @@ private:
bigram.score = tok_score.score;
bigram.size = text.size();
work_queue_.push(bigram);

// Do we need to support is_unused?
rev_merge[text] = std::make_pair(left, right);
}

const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_;
std::map<std::string, std::pair<int, int> > rev_merge;
};

static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
llama_tokenizer tokenizer(vocab);
std::vector<llama_vocab::id> output;

if (text.empty()) {
if (raw_text.empty()) {
return output;
}

@@ -2089,6 +2209,13 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
output.push_back(llama_token_bos());
}

std::string text;
if (escape) {
text = llama_escape_whitespace(raw_text);
} else {
text = raw_text;
}

tokenizer.tokenize(text, output);
return output;
}

@@ -2670,15 +2797,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const char * str = llama_token_to_str(ctx, id);
std::string str = llama_token_to_str(ctx, id);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (*str == 0) {
} else if (str.empty()) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(str));
candidates_decoded.push_back(decode_utf8(str.c_str()));
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
}
}

@@ -2879,9 +3006,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
LLAMA_ASSERT(false);
}

const char * str = llama_token_to_str(ctx, token);
std::string str = llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string
auto code_points = decode_utf8(str);
auto code_points = decode_utf8(str.c_str());
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}

@@ -4132,7 +4259,8 @@ int llama_tokenize_with_model(
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize(model->vocab, text, add_bos);
auto escape = llama_vocab_type(model->vocab) == "spm";
auto res = llama_tokenize(model->vocab, text, add_bos, escape);

if (n_max_tokens < (int) res.size()) {
LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);

@@ -4155,6 +4283,62 @@ int llama_tokenize(
return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
}

std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int length = text.length() + add_bos;
std::vector<llama_token> result(length);
length = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (length < 0) {
result.resize(-length);
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}

int llama_tokenize_bpe(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize(ctx->model.vocab, text, add_bos, false);

if (n_max_tokens < (int) res.size()) {
LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size());
}

for (size_t i = 0; i < res.size(); i++) {
tokens[i] = res[i];
}

return res.size();
}

std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int length = text.length() + add_bos;
std::vector<llama_token> result(length);
length = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (length < 0) {
result.resize(-length);
int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}

int llama_n_vocab_from_model(const struct llama_model * model) {
return model->vocab.id_to_token.size();
}

@@ -4208,16 +4392,88 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}

const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) {
if (token >= llama_n_vocab_from_model(model)) {
return nullptr;
int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(model)) {
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].tok;
if(llama_vocab_type(model->vocab) == "spm") {
result = llama_unescape_whitespace(result);
}
if(result.length() > length) {
return - result.length();
}
strcpy(str, result.c_str());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) {
if(3 > length) {
return -3;
}
strcpy(str, "\xe2\x96\x85");
return 3;
} else if (llama_is_control_token(model->vocab, token)) {
;
} else if (llama_is_byte_token(model->vocab, token)) {
if(1 > length) {
return -1;
}
str[0] = llama_byte_to_char(model->vocab, token);
str[1] = 0x00;
return 1;
}
}

return model->vocab.id_to_token[token].tok.c_str();
return 0;
}

const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
return llama_token_to_str_with_model(&ctx->model, token);
int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) {
return llama_token_to_str_with_model(&ctx->model, token, str, length);
}

std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}

int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
std::string result = ctx->model.vocab.id_to_token[token].tok;
if (result.length() > length) {
return - result.length();
}
strcpy(str, result.c_str());
return result.length();
}
return 0;
}

std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}

llama_token llama_token_bos() {
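The rewritten llama_token_to_str_with_model above returns the number of bytes written on success, 0 for control tokens or out-of-range ids, and the negative of the required length when the destination buffer is too small. A minimal caller-side sketch of that convention follows; the 8-byte starting buffer and the helper name token_to_text are illustrative assumptions, not part of the commit, and the retry pattern mirrors the std::string wrapper added in this hunk.

#include "llama.h"
#include <string>
#include <vector>

// Convert one token id to text via the new length-returning C API.
// `ctx` and `token` are assumed to come from an already loaded model.
static std::string token_to_text(llama_context * ctx, llama_token token) {
    std::vector<char> buf(8); // small initial guess, as in the wrapper above
    int n = llama_token_to_str(ctx, token, buf.data(), (int) buf.size());
    if (n < 0) {              // buffer too small: -n is the required length
        buf.resize(-n);
        n = llama_token_to_str(ctx, token, buf.data(), (int) buf.size());
    }
    return std::string(buf.data(), n > 0 ? n : 0); // n == 0 for control tokens
}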
llama.h (60 lines changed)
@@ -336,6 +336,13 @@ extern "C" {
int n_max_tokens,
bool add_bos);

LLAMA_API int llama_tokenize_bpe(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);

LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,

@@ -377,14 +384,23 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(
LLAMA_API int llama_token_to_str(
const struct llama_context * ctx,
llama_token token);
llama_token token,
char * str,
int length);

LLAMA_API const char * llama_token_to_str_with_model(
LLAMA_API int llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token,
char * str,
int length);

LLAMA_API int llama_token_to_str_with_model(
const struct llama_model * model,
llama_token token);

llama_token token,
char * str,
int length);
// Special tokens
LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(); // end-of-sentence

@@ -472,15 +488,43 @@ extern "C" {
}
#endif

// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
// C++ API, will be moving to common.h soon (TM)
#ifdef LLAMA_API_CPP

#include <vector>
#include <string>

//
// Vocab utils
//

std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const std::string & text,
bool add_bos);

std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos);

std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token);

std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token);

// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL

struct ggml_tensor;

const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);

#endif
#endif // LLAMA_API_CPP

#endif // LLAMA_API_INTERNAL

#endif // LLAMA_H
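When LLAMA_API_CPP is defined before including llama.h, the header now also exposes the std::vector / std::string convenience overloads declared above. A short sketch of how the updated examples in this commit use them; the prompt text and the function name echo_prompt are illustrative, and ctx is assumed to be an already initialized llama_context *.

#define LLAMA_API_CPP
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

// Tokenize a prompt and echo each token, the way main.cpp does after this change.
static void echo_prompt(llama_context * ctx) {
    const std::string prompt = " Hello World"; // illustrative text
    std::vector<llama_token> embd_inp = llama_tokenize(ctx, prompt, /*add_bos=*/true);
    for (llama_token id : embd_inp) {
        fprintf(stderr, "%6d -> '%s'\n", id, llama_token_to_str(ctx, id).c_str());
    }
}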
models/ggml-vocab-aquila.bin (new binary file, not shown)
models/ggml-vocab-llama.bin (new binary file, not shown)
@@ -1,4 +1,19 @@
function(llama_add_test source)
function(llama_build_executable source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source})
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE llama)
endfunction()

function(llama_test_executable name source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
# add_executable(${TEST_TARGET} ${source})
# install(TARGETS ${TEST_TARGET} RUNTIME)
# target_link_libraries(${TEST_TARGET} PRIVATE llama)
add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()

function(llama_build_and_test_executable source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source})
install(TARGETS ${TEST_TARGET} RUNTIME)

@@ -6,11 +21,15 @@ function(llama_add_test source)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()

# llama_add_test(test-double-float.cpp) # SLOW
llama_add_test(test-quantize-fns.cpp)
llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
llama_add_test(test-grad0.cpp) # SLOW
# llama_add_test(test-opt.cpp) # SLOW
# llama_build_and_test_executable(test-double-float.cpp) # SLOW
llama_build_and_test_executable(test-quantize-fns.cpp)
llama_build_and_test_executable(test-quantize-perf.cpp)
llama_build_and_test_executable(test-sampling.cpp)
llama_build_executable(test-tokenizer-0.cpp)
llama_test_executable(test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin)
llama_build_executable(test-tokenizer-1.cpp)
llama_test_executable(test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin)
llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.bin)
llama_build_and_test_executable(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
llama_build_and_test_executable(test-grad0.cpp) # SLOW
# llama_build_and_test_executable(test-opt.cpp) # SLOW
@@ -1,3 +1,4 @@
#define LLAMA_API_CPP // TODO: eliminate me
#include "llama.h"

#include <cstdio>

@@ -5,16 +6,40 @@
#include <map>
#include <vector>

static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
std::string result;
for (int i = 0; i < tokens.size(); ++i) {
result += llama_token_to_str(ctx, tokens[i]);
}
return result;
}

static const std::map<std::string, std::vector<llama_token>> & k_tests()
{
static std::map<std::string, std::vector<llama_token>> _k_tests = {
{ "Hello World", { 1, 10994, 2787, }, },
{ " Hello World", { 1, 15043, 2787, }, },
{ " Hello World!", { 1, 15043, 2787, 29991, }, },
{ " this is 🦙.cpp", { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
{ "w048 7tuijk dsdfhu", { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
{ "нещо на Български", { 1, 821, 4851, 665, 1386, 29713, 1305, }, },
};
{ " ", {1, 259, }, },
{ "\t", { 1, 29871, 12, }, },
{ "\n", { 1, 29871, 13, }, },
{ "\t\n", { 1, 29871, 12, 13, }, },
{ "Hello world", { 1, 15043, 3186, }, },
{ " Hello world", { 1, 29871, 15043, 3186, }, },
{ "Hello World", { 1, 15043, 2787, }, },
{ " Hello World", { 1, 29871, 15043, 2787, }, },
{ " Hello World!", { 1, 29871, 15043, 2787, 29991, }, },
{ " this is 🦙.cpp", { 1, 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
{ "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
{ "нещо на Български", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, },
{ "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
136, 228, 162, 132, 228, 161, 140, }, },
{ "🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
{ 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
};
return _k_tests;
};

@@ -65,9 +90,9 @@ int main(int argc, char **argv) {
}

for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res(test_kv.first.size());
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true);
res.resize(n);
std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first.c_str(), true);
fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
__func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());

bool correct = res.size() == test_kv.second.size();
tests/test-tokenizer-1.cpp (new file, 122 lines)
@@ -0,0 +1,122 @@
#define LLAMA_API_CPP // TODO: eliminate me
#include "llama.h"

#include <cassert>
#include <cstdio>
#include <cstring>
#include <string>
#include <codecvt>
#include <map>
#include <vector>

static std::string vocab_type(llama_context* ctx) {
return llama_n_vocab(ctx) == 32000 ? "spm": "bpe";
}

static std::string escape_whitespace(const std::string& text) {
std::string result;
bool escaping = false;
result += "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
if (!escaping) {
result += "\xe2\x96\x81";
escaping = true;
}
}
else {
escaping = false;
result += text[offs];
}
}
return result;
}

static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
std::string result;
for (int i = 0; i < tokens.size(); ++i) {
result += llama_token_to_str(ctx, tokens[i]);
}
return result;
}

int main(int argc, char **argv) {
if (argc < 2) {
fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
return 1;
}

const std::string fname = argv[1];

fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

llama_model * model;
llama_context * ctx;

llama_backend_init(false);

// load the vocab
{
auto lparams = llama_context_default_params();

lparams.vocab_only = true;

model = llama_load_model_from_file(fname.c_str(), lparams);

if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}

ctx = llama_new_context_with_model(model, lparams);

if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
llama_free_model(model);
return 1;
}
}

const int n_vocab = llama_n_vocab(ctx);

for (int i = 0; i < n_vocab; ++i) {
std::string forward = llama_token_to_str_bpe(ctx, i);
std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false);
if (tokens.size() == 1) {
if (i != tokens[0]) {
std::string backward = llama_token_to_str(ctx, tokens[0]);
fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
return 2;
}
} else {
if ((vocab_type(ctx) == "spm" && i <= 258) ||
(vocab_type(ctx) == "bpe" && (i == 0 || i >= 100000))) {
fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
} else {
fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
return 2;
}
}
}

std::wstring_convert<typename std::codecvt_utf8<wchar_t>, wchar_t> converter;
for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) {
std::wstring wstr(1, ch);
std::string str = converter.to_bytes(wstr);
std::vector<llama_token> tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false);
if (tokens.size() == 1) {
fprintf(stderr, "%s : info: %s tokenized to %d \n",
__func__, str.c_str(), tokens[0]);
}
}

llama_free_model(model);
llama_free(ctx);

llama_backend_free();

return 0;
}