diff --git a/common/common.cpp b/common/common.cpp index 7d983a453..98fc8388c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2212,23 +2212,23 @@ std::tuple llama_init_from_gpt_par std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos, - bool special) { - return llama_tokenize(llama_get_model(ctx), text, add_bos, special); + bool add_special, + bool parse_special) { + return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special); } std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos, - bool special) { + bool add_special, + bool parse_special) { // upper limit for the number of tokens - int n_tokens = text.length() + add_bos; + int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/common/common.h b/common/common.h index 4635e05d6..a7f476c1b 100644 --- a/common/common.h +++ b/common/common.h @@ -223,14 +223,14 @@ void llama_batch_add( std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos, - bool special = false); + bool add_special, + bool parse_special = false); std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos, - bool special = false); + bool add_special, + bool parse_special = false); // tokenizes a token into a piece // should work similar to Python's `tokenizer.id_to_piece` diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 37af6328a..63710676b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -227,15 +227,14 @@ class Model(ABC): return ("pytorch_model.bin",) return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) - def _set_vocab_gpt2(self): - dir_model = self.dir_model - hparams = self.hparams + # used for GPT-2 BPE and WordPiece vocabs + def get_basic_vocab(self) -> tuple[list[str], list[int]]: tokens: list[str] = [] toktypes: list[int] = [] from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(dir_model) - vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} @@ -255,11 +254,15 @@ class Model(ABC): tokens.append(reverse_vocab[i]) toktypes.append(gguf.TokenType.NORMAL) + return tokens, toktypes + + def _set_vocab_gpt2(self) -> None: + tokens, toktypes = self.get_basic_vocab() self.gguf_writer.add_tokenizer_model("gpt2") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_qwen(self): @@ -2043,34 +2046,25 @@ class BertModel(Model): self.gguf_writer.add_pooling_type(pooling_type) def set_vocab(self): - # use huggingface vocab to get all tokens - vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True) - tokens, scores, toktypes = zip(*vocab.all_tokens()) - assert len(tokens) == vocab.vocab_size - self.vocab_size = vocab.vocab_size + tokens, toktypes = self.get_basic_vocab() + self.vocab_size = len(tokens) # we need this to validate the size of the token_type embeddings # though currently we are passing all zeros to the token_type embeddings - n_token_types = len(set(toktypes)) - self.gguf_writer.add_token_type_count(n_token_types) + self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B" # convert to phantom space vocab - def phantom(tok, typ): - if tok.startswith(b"[") and tok.endswith(b"]"): + def phantom(tok): + if tok.startswith("[") and tok.endswith("]"): return tok - if tok.startswith(b"##"): + if tok.startswith("##"): return tok[2:] - return b"\xe2\x96\x81" + tok - tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes)) - - # set up bos and eos tokens (cls and sep) - self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id) - self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id) + return "\u2581" + tok + tokens = list(map(phantom, tokens)) # add vocab to gguf self.gguf_writer.add_tokenizer_model("bert") self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) # handle special tokens @@ -2142,16 +2136,6 @@ class NomicBertModel(BertModel): super().set_gguf_parameters() self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) - def get_tensors(self): - assert self.vocab_size is not None - for name, data in super().get_tensors(): - # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly. - if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size: - rounded_vocab_size = (self.vocab_size + 63) // 64 * 64 - assert data.shape == (rounded_vocab_size, self.hparams["n_embd"]) - data = data[:self.vocab_size, :] - yield name, data - @Model.register("GemmaForCausalLM") class GemmaModel(Model): @@ -2327,7 +2311,8 @@ class MambaModel(Model): data = data.astype(np.float32) # if f16 desired, convert big float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2: + new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else "" + if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2: data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") diff --git a/convert-persimmon-to-gguf.py b/convert-persimmon-to-gguf.py index ccb99279e..69be17f94 100755 --- a/convert-persimmon-to-gguf.py +++ b/convert-persimmon-to-gguf.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import os import sys diff --git a/convert.py b/convert.py index a37aeb5e5..e860ac89f 100755 --- a/convert.py +++ b/convert.py @@ -33,7 +33,7 @@ if 'NO_LOCAL_GGUF' not in os.environ: import gguf if TYPE_CHECKING: - from typing import TypeAlias + from typing_extensions import Self, TypeAlias if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): faulthandler.register(signal.SIGUSR1) @@ -517,7 +517,7 @@ class LlamaHfVocab(Vocab): tokenizer_model = "llama" name = "hfft" - def __init__(self, base_path: Path, ignore_nonllama: bool = False): + def __init__(self, base_path: Path): fname_tokenizer = base_path / FAST_TOKENIZER_FILE # if this fails, FileNotFoundError propagates to caller with open(fname_tokenizer, encoding='utf-8') as f: @@ -525,9 +525,7 @@ class LlamaHfVocab(Vocab): # pre-check so we know if we need transformers tokenizer_model: dict[str, Any] = tokenizer_json['model'] - if ignore_nonllama: - pass # workaround incorrect use of this class for WordPiece - elif ( + if ( tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) or tokenizer_json['decoder']['type'] != 'Sequence' ): @@ -647,16 +645,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: class Tensor(ABC): + ndarray: NDArray data_type: DataType @abstractmethod - def astype(self, data_type: DataType) -> Tensor: ... + def astype(self, data_type: DataType) -> Self: ... @abstractmethod - def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... + def permute(self, n_head: int, n_head_kv: int) -> Self: ... @abstractmethod - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ... @abstractmethod - def part(self, n_part: int) -> UnquantizedTensor: ... + def part(self, n_part: int) -> Self: ... @abstractmethod def to_ggml(self) -> GGMLCompatibleTensor: ... @@ -673,13 +672,13 @@ class UnquantizedTensor(Tensor): self.ndarray = ndarray self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] - def astype(self, data_type: DataType) -> Tensor: + def astype(self, data_type: DataType) -> UnquantizedTensor: dtype = data_type.dtype if self.data_type == DT_BF16: self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) - def to_ggml(self) -> UnquantizedTensor: + def to_ggml(self) -> Self: return self def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 536657526..6a93147d7 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -123,10 +123,10 @@ int main(int argc, char ** argv) { inputs.push_back(inp); } - // add eos if not present + // add SEP if not present for (auto & inp : inputs) { - if (inp.empty() || inp.back() != llama_token_eos(model)) { - inp.push_back(llama_token_eos(model)); + if (inp.empty() || inp.back() != llama_token_sep(model)) { + inp.push_back(llama_token_sep(model)); } } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index d8cb0a642..1bf55f90c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -349,12 +349,13 @@ static void process_logits( static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); const int n_ctx = llama_n_ctx(ctx); auto tim1 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 91c39c5ae..c69dcd06e 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -239,6 +239,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s\n", get_system_info(params).c_str()); } const bool add_bos = llama_should_add_bos_token(model); + GGML_ASSERT(llama_add_eos_token(model) != 1); LOG("add_bos: %d\n", add_bos); bool suff_rm_leading_spc = params.escape; @@ -279,10 +280,10 @@ int main(int argc, char ** argv) { if (ctx_guidance) { LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos); + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); original_prompt_len = original_inp.size(); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index e29da6cb2..75948806e 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -146,7 +146,6 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ int n_past = 0; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama)); std::string system_prompt, user_prompt; size_t image_pos = prompt.find(""); @@ -180,7 +179,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ } } - eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos); + eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true); llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index e2551e7a4..5af6a8ab6 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -64,13 +64,10 @@ int main(int argc, char ** argv) { std::tie(model, ctx) = llama_init_from_gpt_params(params); // Tokenize the prompt - const bool add_bos = llama_should_add_bos_token(model); - LOG("add_bos tgt: %d\n", add_bos); - std::vector inp; std::vector all; - inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + inp = ::llama_tokenize(ctx, params.prompt, true, true); all = inp; const int max_context_size = llama_n_ctx(ctx); diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 46a6bed07..1c230c966 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -28,10 +28,8 @@ int main(int argc, char ** argv){ GGML_ASSERT(model != nullptr); // tokenize the prompt - const bool add_bos = llama_should_add_bos_token(model); - std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + inp = ::llama_tokenize(ctx, params.prompt, true, true); fprintf(stderr, "%s: tokenization done\n", __func__); diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 31f227773..41b62c2fe 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -34,11 +34,8 @@ int main(int argc, char ** argv){ GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt - const bool add_bos = llama_should_add_bos_token(model); - LOG("add_bos tgt: %d\n", add_bos); - std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + inp = ::llama_tokenize(ctx, params.prompt, true, true); llama_ngram_cache ngram_cache_context; llama_ngram_cache ngram_cache_dynamic; diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 2e8c35de3..65ed408a2 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -42,11 +42,8 @@ int main(int argc, char ** argv){ GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt - const bool add_bos = llama_should_add_bos_token(model); - LOG("add_bos tgt: %d\n", add_bos); - std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + inp = ::llama_tokenize(ctx, params.prompt, true, true); llama_ngram_cache ngram_cache_context; llama_ngram_cache ngram_cache_dynamic; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 711f162d7..249fc2bb6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -246,6 +246,7 @@ int main(int argc, char ** argv) { } const bool add_bos = llama_should_add_bos_token(model); + GGML_ASSERT(llama_add_eos_token(model) != 1); LOG("add_bos: %d\n", add_bos); std::vector embd_inp; @@ -255,7 +256,7 @@ int main(int argc, char ** argv) { if (params.chatml) { params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; } - embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + embd_inp = ::llama_tokenize(ctx, params.prompt, true, true); } else { LOG("use session tokens\n"); embd_inp = session_tokens; @@ -277,10 +278,10 @@ int main(int argc, char ** argv) { if (ctx_guidance) { LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true); + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); original_prompt_len = original_inp.size(); @@ -339,14 +340,14 @@ int main(int argc, char ** argv) { } // prefix & suffix for instruct mode - const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true); - const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true, true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); // chatml prefix & suffix - const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true); + const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true); const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true); LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str()); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index c70385c62..bab79aaea 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -315,10 +315,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // BOS tokens will be added for each chunk before eval const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); const int n_ctx = llama_n_ctx(ctx); @@ -454,6 +455,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // BOS tokens will be added for each chunk before eval const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -470,7 +472,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par auto tim1 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); @@ -771,9 +773,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; fprintf(stderr, "================================= is_spm = %d\n", is_spm); - // This is needed as usual for LLaMA models - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - // The tasks should be randomized so the score stabilizes quickly. bool randomize_tasks = true; @@ -818,7 +817,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] ); for (size_t j = 0; j < 4; j++) { hs_cur.ending[j] = prompt_lines[idx*6+2+j]; - hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos); + hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true); } // determine the common prefix of the endings @@ -837,7 +836,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { hs_cur.seq_tokens[2].size() - hs_cur.common_prefix + hs_cur.seq_tokens[3].size() - hs_cur.common_prefix; - //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size()); + //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size()); // Delete the selected random example from the prompt if (randomize_tasks) { @@ -1110,12 +1109,9 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { fprintf(stderr, "%s : tokenizing selected tasks\n", __func__); - // This is needed as usual for LLaMA models - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - for (auto & task : data) { - task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos); - task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos); + task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true); + task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true); task.common_prefix = 0; for (size_t k = 0; k < task.seq_tokens[0].size(); k++) { @@ -1130,8 +1126,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { task.seq_tokens[0].size() - task.common_prefix + task.seq_tokens[1].size() - task.common_prefix; - task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size(); - task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size(); + task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size(); + task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size(); } fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); @@ -1322,7 +1318,7 @@ struct multiple_choice_task { std::vector log_probs; }; -static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) { +static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) { if (task.question.empty() || task.mc1.answers.empty()) { if (log_error) { printf("%s: found bad task with empty question and/or answers\n", __func__); @@ -1337,7 +1333,7 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, } return false; } - task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos)); + task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true)); } auto min_len = task.seq_tokens.front().size(); for (auto& seq : task.seq_tokens) { @@ -1436,9 +1432,6 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params n_task = params.multiple_choice_tasks; } - // This is needed as usual for LLaMA models - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - printf("%s: preparing task data", __func__); fflush(stdout); if (n_task > 500) { @@ -1446,7 +1439,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params fflush(stdout); std::atomic counter(0); std::atomic n_bad(0); - auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () { + auto prepare = [&counter, &n_bad, &tasks, ctx] () { int num_tasks = tasks.size(); int n_bad_local = 0; while (true) { @@ -1457,7 +1450,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params } int last = std::min(first + K_TOKEN_CHUNK, num_tasks); for (int i = first; i < last; ++i) { - if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local; + if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local; } } }; @@ -1479,7 +1472,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params int i_task = 0; for (auto& task : tasks) { ++i_task; - if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) { + if (!multiple_choice_prepare_one_task(ctx, task, true)) { return; } if (i_task%n_dot == 0) { @@ -1715,6 +1708,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6c64fe3e1..2e791190b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -689,6 +689,7 @@ struct server_context { n_ctx = llama_n_ctx(ctx); add_bos_token = llama_should_add_bos_token(model); + GGML_ASSERT(llama_add_eos_token(model) != 1); return true; } @@ -758,7 +759,7 @@ struct server_context { metrics.init(); } - std::vector tokenize(const json & json_prompt, bool add_bos) const { + std::vector tokenize(const json & json_prompt, bool add_special) const { // TODO: currently, we tokenize using special tokens by default // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) // but it's better compared to completely ignoring ChatML and other chat templates @@ -776,7 +777,7 @@ struct server_context { std::vector p; if (first) { - p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); + p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); first = false; } else { p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); @@ -793,7 +794,7 @@ struct server_context { } } else { auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); + prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); } return prompt_tokens; @@ -1058,7 +1059,7 @@ struct server_context { system_tokens.clear(); if (!system_prompt.empty()) { - system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); + system_tokens = ::llama_tokenize(ctx, system_prompt, true); llama_batch_clear(batch); @@ -1914,7 +1915,7 @@ struct server_context { prefix_tokens.push_back(llama_token_middle(model)); prompt_tokens = prefix_tokens; } else { - prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt + prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } slot.n_past = 0; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6e0815b36..6a7367b0c 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -76,6 +76,28 @@ int main(int argc, char ** argv) { params.n_threads_batch = params.n_threads_batch_draft; std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LOG("vocab_type tgt: %d\n", vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(model_dft); + LOG("vocab_type dft: %d\n", vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__); + fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); + return 1; + } + + if ( + llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + llama_token_eos(model_tgt) != llama_token_eos(model_dft) + ) { + fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__); + return 1; + } + { const int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_dft = llama_n_vocab(model_dft); @@ -105,20 +127,8 @@ int main(int argc, char ** argv) { // Tokenize the prompt - const bool add_bos_tgt = llama_should_add_bos_token(model_tgt); - LOG("add_bos tgt: %d\n", add_bos_tgt); - - const bool add_bos_dft = llama_should_add_bos_token(model_dft); - LOG("add_bos dft: %d\n", add_bos_dft); - - if (add_bos_tgt != add_bos_dft) { - fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__); - fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt); - return 1; - } - std::vector inp; - inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true); + inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true); const int max_context_size = llama_n_ctx(ctx_tgt); const int max_tokens_list_size = max_context_size - 4; diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index d95a92475..8b1baea80 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -26,11 +26,9 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - const bool add_bos = llama_should_add_bos_token(model); - std::vector tokens; - tokens = ::llama_tokenize(model, prompt, add_bos, true); + tokens = ::llama_tokenize(model, prompt, true, true); for (int i = 0; i < (int) tokens.size(); i++) { if (printing_ids) { diff --git a/llama.cpp b/llama.cpp index 6a090d1bb..8dbf47486 100644 --- a/llama.cpp +++ b/llama.cpp @@ -318,6 +318,8 @@ enum llm_kv { LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_PREFIX, @@ -388,6 +390,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, @@ -2018,11 +2022,13 @@ struct llama_vocab { std::map, int> bpe_ranks; // default LLaMA special tokens - id special_bos_id = 1; - id special_eos_id = 2; - id special_unk_id = 0; - id special_sep_id = -1; - id special_pad_id = -1; + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = 0; + id special_sep_id = -1; + id special_pad_id = -1; + id special_cls_id = -1; + id special_mask_id = -1; int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. @@ -3978,7 +3984,9 @@ static void llm_load_hparams( } // TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); +static std::vector llama_tokenize_internal( + const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false +); static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); static void llm_load_vocab( @@ -4000,23 +4008,27 @@ static void llm_load_vocab( vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.linefeed_id = -1; + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + vocab.linefeed_id = -1; return; } else if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); if (add_space_prefix_keyidx != -1) { @@ -4051,20 +4063,24 @@ static void llm_load_vocab( } // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; } else if (tokenizer_name == "bert") { vocab.type = LLAMA_VOCAB_TYPE_WPM; // default special tokens - vocab.special_bos_id = 101; - vocab.special_eos_id = 102; - vocab.special_unk_id = 100; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = 100; + vocab.special_sep_id = 102; + vocab.special_pad_id = 0; + vocab.special_cls_id = 101; + vocab.special_mask_id = 103; vocab.add_space_prefix = false; } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); @@ -4127,11 +4143,13 @@ static void llm_load_vocab( // special tokens { const std::vector> special_token_types = { - { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, - { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, - { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, - { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, - { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, + { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, }; for (const auto & it : special_token_types) { const std::string & key = kv(std::get<0>(it)); @@ -4323,12 +4341,14 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } + if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } } // Returns false if cancelled by progress_callback @@ -11358,9 +11378,6 @@ struct llm_tokenizer_wpm { output.push_back(vocab.special_unk_id); } } - - // append eos token - output.push_back(vocab.special_eos_id); } std::vector preprocess(const std::string & text) { @@ -11565,30 +11582,28 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< } } -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { std::vector output; - - // OG tokenizer behavior: - // - // tokenizer.encode('', add_bos=True) returns [1] - // tokenizer.encode('', add_bos=False) returns [] - - if (bos && vocab.special_bos_id != -1) { - output.push_back(vocab.special_bos_id); - } - - if (raw_text.empty()) { - return output; - } - std::forward_list fragment_buffer; - fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); - if (special) tokenizer_st_partition(vocab, fragment_buffer); + if (!raw_text.empty()) { + fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); + if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); + } switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { + // OG tokenizer behavior: + // + // tokenizer.encode('', add_special_tokens=True) returns [1] + // tokenizer.encode('', add_special_tokens=False) returns [] + + if (add_special && vocab.special_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { // without adding this leading whitespace, we do not get the same results as the original tokenizer @@ -11614,9 +11629,19 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + if (add_special && vocab.special_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } } break; case LLAMA_VOCAB_TYPE_BPE: { + if (add_special && vocab.special_add_bos == 1) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -11630,9 +11655,16 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + GGML_ASSERT(vocab.special_add_eos != 1); } break; case LLAMA_VOCAB_TYPE_WPM: { + if (add_special) { + GGML_ASSERT(vocab.special_cls_id != -1); + output.push_back(vocab.special_cls_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -11646,6 +11678,11 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + if (add_special) { + GGML_ASSERT(vocab.special_sep_id != -1); + output.push_back(vocab.special_sep_id); + } } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ASSERT(false); @@ -16104,6 +16141,14 @@ llama_token llama_token_eos(const struct llama_model * model) { return model->vocab.special_eos_id; } +llama_token llama_token_cls(const struct llama_model * model) { + return model->vocab.special_cls_id; +} + +llama_token llama_token_sep(const struct llama_model * model) { + return model->vocab.special_sep_id; +} + llama_token llama_token_nl(const struct llama_model * model) { return model->vocab.linefeed_id; } @@ -16138,9 +16183,9 @@ int32_t llama_tokenize( int32_t text_len, llama_token * tokens, int32_t n_tokens_max, - bool add_bos, - bool special) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); + bool add_special, + bool parse_special) { + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); diff --git a/llama.h b/llama.h index 6a5bbe26d..b770a275f 100644 --- a/llama.h +++ b/llama.h @@ -786,6 +786,8 @@ extern "C" { // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification + LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line // Returns -1 if unknown, 1 for true or 0 for false. @@ -808,16 +810,16 @@ extern "C" { /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @return Returns the number of tokens on success, no more than n_tokens_max /// @return Returns a negative number on failure - the number of tokens that would have been returned - /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. - /// Does not insert a leading space. + /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated + /// as plaintext. Does not insert a leading space. LLAMA_API int32_t llama_tokenize( const struct llama_model * model, const char * text, int32_t text_len, llama_token * tokens, int32_t n_tokens_max, - bool add_bos, - bool special); + bool add_special, + bool parse_special); // Token Id -> Piece. // Uses the vocabulary in the provided context.