From 807b0c49ff7071094f97ebc3a0a8e2b9e274f503 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 4 Jul 2024 15:46:11 +0200 Subject: [PATCH] Inference support for T5 and FLAN-T5 model families (#5763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * llama : add inference support and model types for T5 and FLAN-T5 model families * llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token() * common, llama-cli, llama-batched : add support for encoder-decoder models * convert-hf : handle shared token embeddings tensors in T5Model * convert-hf : add support for SentencePiece BPE tokenizer in T5Model (for Pile-T5 models) * convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model * convert : add t5 tokenizer tests, use "slow" HF tokenizer for t5 --------- Co-authored-by: Stanisław Szymczyk Co-authored-by: Georgi Gerganov --- common/common.cpp | 19 +- convert-hf-to-gguf-update.py | 19 +- convert-hf-to-gguf.py | 46 +- examples/batched/batched.cpp | 34 +- examples/main/main.cpp | 22 +- include/llama.h | 15 + models/ggml-vocab-bert-bge.gguf.inp | 2 + models/ggml-vocab-bert-bge.gguf.out | 1 + models/ggml-vocab-command-r.gguf.inp | 2 + models/ggml-vocab-command-r.gguf.out | 1 + models/ggml-vocab-deepseek-coder.gguf.inp | 2 + models/ggml-vocab-deepseek-coder.gguf.out | 1 + models/ggml-vocab-deepseek-llm.gguf.inp | 2 + models/ggml-vocab-deepseek-llm.gguf.out | 1 + models/ggml-vocab-falcon.gguf.inp | 2 + models/ggml-vocab-falcon.gguf.out | 1 + models/ggml-vocab-gpt-2.gguf.inp | 2 + models/ggml-vocab-gpt-2.gguf.out | 1 + models/ggml-vocab-llama-bpe.gguf.inp | 2 + models/ggml-vocab-llama-bpe.gguf.out | 1 + models/ggml-vocab-llama-spm.gguf.inp | 2 + models/ggml-vocab-llama-spm.gguf.out | 1 + models/ggml-vocab-mpt.gguf.inp | 2 + models/ggml-vocab-mpt.gguf.out | 1 + models/ggml-vocab-phi-3.gguf.inp | 2 + models/ggml-vocab-phi-3.gguf.out | 1 + models/ggml-vocab-qwen2.gguf.inp | 2 + models/ggml-vocab-qwen2.gguf.out | 1 + models/ggml-vocab-refact.gguf.inp | 2 + models/ggml-vocab-refact.gguf.out | 1 + models/ggml-vocab-starcoder.gguf.inp | 2 + models/ggml-vocab-starcoder.gguf.out | 1 + src/llama.cpp | 783 +++++++++++++++++++++- 33 files changed, 946 insertions(+), 31 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 2c05a4d4a..4138dc3b6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2070,7 +2070,24 @@ std::tuple llama_init_from_gpt_par if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + std::vector tmp; + llama_token bos = llama_token_bos(model); + llama_token eos = llama_token_eos(model); + // some models (e.g. 
T5) don't have a BOS token + if (bos != -1) { + tmp.push_back(bos); + } + tmp.push_back(eos); + + if (llama_model_has_encoder(model)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py index 9eb406cb8..21a306255 100755 --- a/convert-hf-to-gguf-update.py +++ b/convert-hf-to-gguf-update.py @@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum): SPM = auto() BPE = auto() WPM = auto() + UGM = auto() # TODO: this string has to exercise as much pre-tokenizer functionality as possible @@ -89,6 +90,7 @@ models = [ {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, ] @@ -110,9 +112,13 @@ def download_model(model): os.makedirs(f"models/tokenizers/{name}", exist_ok=True) files = ["config.json", "tokenizer.json", "tokenizer_config.json"] + if tokt == TOKENIZER_TYPE.SPM: files.append("tokenizer.model") + if tokt == TOKENIZER_TYPE.UGM: + files.append("spiece.model") + for file in files: save_path = f"models/tokenizers/{name}/{file}" if os.path.isfile(save_path): @@ -135,7 +141,7 @@ for model in models: name = model["name"] tokt = model["tokt"] - if tokt == TOKENIZER_TYPE.SPM: + if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: continue # Skip if the tokenizer folder does not exist or there are other download issues previously @@ -145,7 +151,10 @@ for model in models: # create the tokenizer try: - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") except OSError as e: logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") continue # Skip to the next model if the tokenizer can't be loaded @@ -266,6 +275,7 @@ tests = [ "\n =", "' era", "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", + "!!!!!!", "3", "33", "333", @@ -304,7 +314,10 @@ for model in models: # create the tokenizer try: - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") except OSError as e: logger.error(f"Failed to load tokenizer for model {name}. 
Error: {e}") continue # Skip this model and continue with the next one in the loop diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index bae145589..a1d165023 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2853,11 +2853,17 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("T5ForConditionalGeneration") @Model.register("T5WithLMHeadModel") +@Model.register("T5ForConditionalGeneration") +@Model.register("MT5ForConditionalGeneration") +@Model.register("UMT5ForConditionalGeneration") class T5Model(Model): model_arch = gguf.MODEL_ARCH.T5 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_token_embeddings_found = False + def set_vocab(self): # to avoid TypeError: Descriptors cannot be created directly # exception when importing sentencepiece_model_pb2 @@ -2865,17 +2871,29 @@ class T5Model(Model): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'spiece.model' + tokenizer_path = self.dir_model / 'tokenizer.model' + + # many older models use spiece.model tokenizer model filename + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'spiece.model' if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") sentencepiece_model = model.ModelProto() sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + # some models like Pile-T5 family use BPE tokenizer instead of Unigram + if sentencepiece_model.trainer_spec.model_type == 2: # BPE + # assure the tokenizer model file name is correct + assert tokenizer_path.name == 'tokenizer.model' + return self._set_vocab_sentencepiece() + else: + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) @@ -2945,7 +2963,10 @@ class T5Model(Model): def set_gguf_parameters(self): self.gguf_writer.add_name("T5") - self.gguf_writer.add_context_length(self.hparams["n_positions"]) + if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: + logger.warning("Couldn't find context length in config.json, assuming default value of 512") + n_ctx = 512 + self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_block_count(self.hparams["num_layers"]) @@ -2961,12 +2982,17 @@ class T5Model(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or - # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor - # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight". 
- if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight": - logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.") - return [] + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", + # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored + # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder + # and decoder and ignore the remaining ones. + if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if not self.shared_token_embeddings_found: + name = "shared.weight" + self.shared_token_embeddings_found = True + else: + logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + return [] return [(self.map_tensor_name(name), data_torch)] diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 62d9b144d..2442e954d 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -93,14 +93,34 @@ int main(int argc, char ** argv) { // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1); + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); + + std::vector seq_ids(n_parallel, 0); + for (int32_t i = 0; i < n_parallel; ++i) { + seq_ids[i] = i; + } // evaluate the initial prompt for (size_t i = 0; i < tokens_list.size(); ++i) { - llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + llama_batch_add(batch, tokens_list[i], i, seq_ids, false); } GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); + if (llama_model_has_encoder(model)) { + if (llama_encode(ctx, batch)) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = llama_token_bos(model); + } + + llama_batch_clear(batch); + llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + } + // llama_decode will output logits only for the last token of the prompt batch.logits[batch.n_tokens - 1] = true; @@ -109,11 +129,11 @@ int main(int argc, char ** argv) { return 1; } - // assign the system KV cache to all parallel sequences - // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them - for (int32_t i = 1; i < n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); - } + //// assign the system KV cache to all parallel sequences + //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them + //for (int32_t i = 1; i < n_parallel; ++i) { + // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + //} if (n_parallel > 1) { LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d512953b9..22bb37889 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -255,7 +255,9 @@ int main(int argc, char ** argv) { } const bool add_bos = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); + if (!llama_model_has_encoder(model)) { + GGML_ASSERT(llama_add_eos_token(model) != 1); + } LOG("add_bos: %d\n", add_bos); std::vector embd_inp; @@ -517,6 +519,24 @@ int main(int argc, char ** argv) { 
exit(1); } + if (llama_model_has_encoder(model)) { + int enc_input_size = embd_inp.size(); + llama_token * enc_input_buf = embd_inp.data(); + + if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = llama_token_bos(model); + } + + embd_inp.clear(); + embd_inp.push_back(decoder_start_token_id); + } + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (!embd.empty()) { diff --git a/include/llama.h b/include/llama.h index c5b618292..aca79d2b8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -485,6 +485,13 @@ extern "C" { // Get a llama model tensor LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + // Returns true if the model contains an encoder that requires llama_encode() call + LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); + + // For encoder-decoder models, this function returns id of the token that must be provided + // to the decoder to start generating output sequence. For other models, it returns -1. + LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -770,6 +777,14 @@ extern "C" { // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); + // Processes a batch of tokens with the ecoder part of the encoder-decoder model. + // Stores the encoder output internally for later use by the decoder cross-attention layers. + // 0 - success + // < 0 - error + LLAMA_API int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch); + // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) diff --git a/models/ggml-vocab-bert-bge.gguf.inp b/models/ggml-vocab-bert-bge.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-bert-bge.gguf.inp +++ b/models/ggml-vocab-bert-bge.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-bert-bge.gguf.out b/models/ggml-vocab-bert-bge.gguf.out index 82d4ed1c1..a62566ce7 100644 --- a/models/ggml-vocab-bert-bge.gguf.out +++ b/models/ggml-vocab-bert-bge.gguf.out @@ -31,6 +31,7 @@ 1027 1005 3690 7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995 + 999 999 999 999 999 999 1017 3943 21211 diff --git a/models/ggml-vocab-command-r.gguf.inp b/models/ggml-vocab-command-r.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-command-r.gguf.inp +++ b/models/ggml-vocab-command-r.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-command-r.gguf.out b/models/ggml-vocab-command-r.gguf.out index 939b9dc30..3f6b41888 100644 --- a/models/ggml-vocab-command-r.gguf.out +++ b/models/ggml-vocab-command-r.gguf.out @@ -31,6 +31,7 @@ 206 1857 14 4515 28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372 + 57178 10251 26 26 26 26 26 26 diff --git a/models/ggml-vocab-deepseek-coder.gguf.inp b/models/ggml-vocab-deepseek-coder.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.inp +++ b/models/ggml-vocab-deepseek-coder.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-deepseek-coder.gguf.out b/models/ggml-vocab-deepseek-coder.gguf.out index a43e3f0f1..52c4111a1 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.out +++ b/models/ggml-vocab-deepseek-coder.gguf.out @@ -31,6 +31,7 @@ 185 405 6 2895 17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239 + 15330 3023 18 18 18 18 18 18 diff --git a/models/ggml-vocab-deepseek-llm.gguf.inp b/models/ggml-vocab-deepseek-llm.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.inp +++ b/models/ggml-vocab-deepseek-llm.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-deepseek-llm.gguf.out b/models/ggml-vocab-deepseek-llm.gguf.out index d31ac1cc6..0191b7a11 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.out +++ b/models/ggml-vocab-deepseek-llm.gguf.out @@ -31,6 +31,7 @@ 185 403 6 2906 17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239 + 15278 3033 18 18 18 18 18 18 diff --git a/models/ggml-vocab-falcon.gguf.inp b/models/ggml-vocab-falcon.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-falcon.gguf.inp +++ b/models/ggml-vocab-falcon.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-falcon.gguf.out b/models/ggml-vocab-falcon.gguf.out index 1ab70fe70..64a48d97f 100644 --- a/models/ggml-vocab-falcon.gguf.out +++ b/models/ggml-vocab-falcon.gguf.out @@ -31,6 +31,7 @@ 1212 40 18 4932 9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236 + 51520 30 3138 22287 diff --git a/models/ggml-vocab-gpt-2.gguf.inp b/models/ggml-vocab-gpt-2.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-gpt-2.gguf.inp +++ b/models/ggml-vocab-gpt-2.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-gpt-2.gguf.out b/models/ggml-vocab-gpt-2.gguf.out index 88217d3fa..17a13bdfc 100644 --- a/models/ggml-vocab-gpt-2.gguf.out +++ b/models/ggml-vocab-gpt-2.gguf.out @@ -31,6 +31,7 @@ 198 796 6 6980 15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252 + 13896 3228 18 2091 20370 diff --git a/models/ggml-vocab-llama-bpe.gguf.inp b/models/ggml-vocab-llama-bpe.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-llama-bpe.gguf.inp +++ b/models/ggml-vocab-llama-bpe.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out index bb1fe229c..4b35cf93f 100644 --- a/models/ggml-vocab-llama-bpe.gguf.out +++ b/models/ggml-vocab-llama-bpe.gguf.out @@ -31,6 +31,7 @@ 198 284 6 11639 9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 + 17523 3001 18 1644 8765 diff --git a/models/ggml-vocab-llama-spm.gguf.inp b/models/ggml-vocab-llama-spm.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-llama-spm.gguf.inp +++ b/models/ggml-vocab-llama-spm.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-llama-spm.gguf.out b/models/ggml-vocab-llama-spm.gguf.out index 1c3b0a2c9..93aacf8ba 100644 --- a/models/ggml-vocab-llama-spm.gguf.out +++ b/models/ggml-vocab-llama-spm.gguf.out @@ -31,6 +31,7 @@ 29871 13 353 525 3152 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 + 1738 6824 21004 29871 29941 29871 29941 29941 29871 29941 29941 29941 diff --git a/models/ggml-vocab-mpt.gguf.inp b/models/ggml-vocab-mpt.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-mpt.gguf.inp +++ b/models/ggml-vocab-mpt.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-mpt.gguf.out b/models/ggml-vocab-mpt.gguf.out index d4a877d1c..372c751bf 100644 --- a/models/ggml-vocab-mpt.gguf.out +++ b/models/ggml-vocab-mpt.gguf.out @@ -31,6 +31,7 @@ 187 426 8 8685 12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241 + 18963 4672 20 1610 20084 diff --git a/models/ggml-vocab-phi-3.gguf.inp b/models/ggml-vocab-phi-3.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-phi-3.gguf.inp +++ b/models/ggml-vocab-phi-3.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-phi-3.gguf.out b/models/ggml-vocab-phi-3.gguf.out index 1c3b0a2c9..93aacf8ba 100644 --- a/models/ggml-vocab-phi-3.gguf.out +++ b/models/ggml-vocab-phi-3.gguf.out @@ -31,6 +31,7 @@ 29871 13 353 525 3152 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 + 1738 6824 21004 29871 29941 29871 29941 29941 29871 29941 29941 29941 diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-qwen2.gguf.inp +++ b/models/ggml-vocab-qwen2.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out index 4ab275396..18b4b45cd 100644 --- a/models/ggml-vocab-qwen2.gguf.out +++ b/models/ggml-vocab-qwen2.gguf.out @@ -31,6 +31,7 @@ 198 284 6 11385 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 + 17085 2928 18 18 18 18 18 18 diff --git a/models/ggml-vocab-refact.gguf.inp b/models/ggml-vocab-refact.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-refact.gguf.inp +++ b/models/ggml-vocab-refact.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-refact.gguf.out b/models/ggml-vocab-refact.gguf.out index 46d8b4aec..63d8305c3 100644 --- a/models/ggml-vocab-refact.gguf.out +++ b/models/ggml-vocab-refact.gguf.out @@ -31,6 +31,7 @@ 203 280 25 34666 8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838 + 9163 3202 37 37 37 37 37 37 diff --git a/models/ggml-vocab-starcoder.gguf.inp b/models/ggml-vocab-starcoder.gguf.inp index 5b4aeb31a..9baf7d77a 100644 --- a/models/ggml-vocab-starcoder.gguf.inp +++ b/models/ggml-vocab-starcoder.gguf.inp @@ -73,6 +73,8 @@ __ggml_vocab_test__ __ggml_vocab_test__ Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ __ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ 3 __ggml_vocab_test__ 33 diff --git a/models/ggml-vocab-starcoder.gguf.out b/models/ggml-vocab-starcoder.gguf.out index 9ce2698a9..87e2465d3 100644 --- a/models/ggml-vocab-starcoder.gguf.out +++ b/models/ggml-vocab-starcoder.gguf.out @@ -31,6 +31,7 @@ 222 299 44 34719 8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892 + 9221 3226 56 56 56 56 56 56 diff --git a/src/llama.cpp b/src/llama.cpp index 3d131b325..1fe2b9f79 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2044,12 +2044,18 @@ enum e_model { MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_60M, MODEL_70M, + MODEL_80M, MODEL_109M, MODEL_137M, MODEL_160M, + MODEL_220M, + MODEL_250M, MODEL_335M, MODEL_410M, + MODEL_770M, + MODEL_780M, MODEL_0_5B, MODEL_1B, MODEL_1_3B, @@ -2061,6 +2067,7 @@ enum e_model { MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_11B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -2112,6 +2119,7 @@ struct llama_hparams { uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; @@ -2147,6 +2155,10 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = -1; + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; @@ -2167,6 +2179,7 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; @@ -2182,6 +2195,8 @@ struct llama_hparams { if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; + const float EPSILON = 1e-9f; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; @@ -2254,6 +2269,7 @@ struct llama_cparams { void * cb_eval_user_data; }; +// TODO: separate into "llama_layer_enc" and "llama_layer_dec" struct llama_layer { // normalization struct ggml_tensor * attn_norm; @@ -2271,6 +2287,8 @@ struct llama_layer { struct ggml_tensor * attn_sub_norm; struct ggml_tensor * attn_post_norm; struct ggml_tensor * ffn_sub_norm; + struct ggml_tensor * attn_norm_cross; + struct ggml_tensor * attn_norm_enc; // attention struct ggml_tensor * wq; @@ -2282,6 +2300,14 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * wq_cross; + struct ggml_tensor * wk_cross; + struct ggml_tensor * wv_cross; + struct ggml_tensor * wo_cross; + struct ggml_tensor * wq_enc; + struct ggml_tensor * wk_enc; + struct ggml_tensor * wv_enc; + struct ggml_tensor * wo_enc; // attention bias struct ggml_tensor * bq; @@ -2290,6 +2316,11 @@ struct llama_layer { struct ggml_tensor * bo; struct ggml_tensor * bqkv; + // relative position bias + struct ggml_tensor * attn_rel_b; + struct ggml_tensor * attn_rel_b_enc; + struct ggml_tensor * attn_rel_b_cross; + // normalization struct ggml_tensor * 
ffn_norm; struct ggml_tensor * ffn_norm_b; @@ -2297,11 +2328,15 @@ struct llama_layer { struct ggml_tensor * layer_out_norm; struct ggml_tensor * layer_out_norm_b; struct ggml_tensor * ffn_norm_exps; + struct ggml_tensor * ffn_norm_enc; // ff struct ggml_tensor * ffn_gate; // w1 struct ggml_tensor * ffn_down; // w2 struct ggml_tensor * ffn_up; // w3 + struct ggml_tensor * ffn_gate_enc; + struct ggml_tensor * ffn_down_enc; + struct ggml_tensor * ffn_up_enc; // ff MoE struct ggml_tensor * ffn_gate_inp; @@ -2535,6 +2570,7 @@ struct llama_model { struct ggml_tensor * output_norm_b; struct ggml_tensor * output; struct ggml_tensor * output_b; + struct ggml_tensor * output_norm_enc; std::vector layers; @@ -2663,6 +2699,13 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; + // whether we are computing encoder output or decoder output + bool is_encoding = false; + + // output of the encoder part of the encoder-decoder models + std::vector embd_enc; + std::vector> seq_ids_enc; + // memory buffers used to evaluate the model std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; @@ -2683,6 +2726,9 @@ struct llama_context { struct ggml_tensor * inp_s_copy; // I32 [kv_size] struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] // control vectors struct llama_control_vector cvec; @@ -4286,12 +4332,18 @@ static const char * llama_model_type_name(e_model type) { case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; + case MODEL_60M: return "60M"; case MODEL_70M: return "70M"; + case MODEL_80M: return "80M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; case MODEL_160M: return "160M"; + case MODEL_220M: return "220M"; + case MODEL_250M: return "250M"; case MODEL_335M: return "335M"; case MODEL_410M: return "410M"; + case MODEL_770M: return "770M"; + case MODEL_780M: return "780M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; case MODEL_1_3B: return "1.3B"; @@ -4303,6 +4355,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_11B: return "11B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -4917,6 +4970,38 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_60M; break; // t5-small + case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff) { + case 3072: model.type = e_model::MODEL_220M; break; // t5-base + case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff) { + case 4096: model.type = e_model::MODEL_770M; break; // t5-large + case 2816: model.type = 
e_model::MODEL_780M; break; // flan-t5-large + case 16384: model.type = e_model::MODEL_3B; break; // t5-3b + case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl + case 65536: model.type = e_model::MODEL_11B; break; // t5-11b + case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_JAIS: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -6996,6 +7081,64 @@ static bool llm_load_tensors( layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; + case LLM_ARCH_T5: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.attn_norm_cross = 
ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; case LLM_ARCH_JAIS: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -7787,6 +7930,7 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_outputs; + const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_ctx_orig; @@ -7836,6 +7980,7 @@ struct llm_build_context { n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), + n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 
0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), @@ -7867,6 +8012,9 @@ struct llm_build_context { lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; lctx.inp_s_seq = nullptr; + lctx.inp_pos_bucket = nullptr; + lctx.inp_embd_enc = nullptr; + lctx.inp_KQ_mask_cross = nullptr; } void free() { @@ -8119,6 +8267,53 @@ struct llm_build_context { return gf; } + struct ggml_tensor * llm_build_pos_bucket(bool causal) { + if (causal) { + lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + } else { + lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); + } + + ggml_set_input(lctx.inp_pos_bucket); + cb(lctx.inp_pos_bucket, "pos_bucket", -1); + + return lctx.inp_pos_bucket; + } + + struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) { + struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0); + cb(pos_bucket_1d, "pos_bucket_1d", -1); + + struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d); + cb(pos_bias, "pos_bias", -1); + + pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0); + cb(pos_bias, "pos_bias", -1); + + pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3); + cb(pos_bias, "pos_bias", -1); + + pos_bias = ggml_cont(ctx0, pos_bias); + cb(pos_bias, "pos_bias", -1); + + return pos_bias; + } + + struct ggml_tensor * llm_build_inp_embd_enc() { + const int64_t n_embd = hparams.n_embd; + lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); + ggml_set_input(lctx.inp_embd_enc); + cb(lctx.inp_embd_enc, "embd_enc", -1); + return lctx.inp_embd_enc; + } + + struct ggml_tensor * llm_build_inp_KQ_mask_cross() { + lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(lctx.inp_KQ_mask_cross); + cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1); + return lctx.inp_KQ_mask_cross; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12426,6 +12621,321 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_t5() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + if (lctx.is_encoding) { + struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur); + cb(Qcur, 
"Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); + + kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm_enc, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + } else { + struct ggml_tensor * embd_enc = llm_build_inp_embd_enc(); + struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true); + + struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? 
model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); + + kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + struct ggml_tensor * inpCA = cur; + + // norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm_cross", il); + + // cross-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + } + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_jais() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12748,6 +13258,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet(); } break; + case LLM_ARCH_T5: + { + result = llm.build_t5(); + } break; case LLM_ARCH_JAIS: { result = llm.build_jais(); @@ -12790,6 +13304,30 @@ static void llama_set_s_copy(llama_context & lctx) { } } +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + return relative_bucket; +} + static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // // set input data @@ -12855,7 +13393,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (lctx.inp_KQ_mask) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { + if (cparams.causal_attn && !lctx.is_encoding) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -12908,7 +13446,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } else { // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; - const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; + const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? 
kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13072,6 +13610,70 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + + if (lctx.inp_pos_bucket) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + + int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + + if (!lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } else { + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } + } + + if (!lctx.is_encoding && lctx.inp_embd_enc) { + assert(lctx.inp_embd_enc->type == GGML_TYPE_F32); + assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size()); + + ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc)); + } + + if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) { + const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd; + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); + + float * data = (float *) lctx.inp_KQ_mask_cross->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_output_enc; ++i) { + float f = -INFINITY; + for (int s = 0; s < batch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = batch.seq_id[j][s]; + if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_output_enc; ++j) { + data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; + } + } + } + } } // Make sure enough space is available for outputs. @@ -13088,7 +13690,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { // TODO: use a per-batch flag for logits presence instead const bool has_logits = !cparams.embeddings; - const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? 
n_embd*n_outputs_max : 0; @@ -13180,6 +13782,7 @@ static int llama_decode_internal( llama_context & lctx, llama_batch batch_all) { // TODO: rename back to batch + lctx.is_encoding = false; const uint32_t n_tokens_all = batch_all.n_tokens; if (n_tokens_all == 0) { @@ -13212,6 +13815,7 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; + // TODO: simplify or deprecate std::vector pos; std::vector n_seq_id; std::vector seq_id_arr; @@ -13475,6 +14079,138 @@ static int llama_decode_internal( return 0; } +// encode a batch of tokens by evaluating the encoder part of the transformer +// +// - lctx: llama context +// - batch: batch to evaluate +// +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_encode_internal( + llama_context & lctx, + llama_batch batch) { + + lctx.is_encoding = true; + + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); + return -1; + } + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot + GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); + + if (lctx.t_compute_start_us == 0) { + lctx.t_compute_start_us = ggml_time_us(); + } + + lctx.n_queued_tokens += n_tokens; + + const int64_t n_embd = hparams.n_embd; + + // TODO: simplify or deprecate + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; + + // reserve output buffer + if (llama_output_reserve(lctx, n_tokens) < n_tokens) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); + return -2; + }; + + for (uint32_t i = 0; i < n_tokens; ++i) { + lctx.output_ids[i] = i; + } + + lctx.inp_embd_enc = NULL; + lctx.n_outputs = n_tokens; + + const int n_threads = n_tokens == 1 ? 
cparams.n_threads : cparams.n_threads_batch; + GGML_ASSERT(n_threads > 0); + + // helpers for smoother batch API transition + // after deprecating the llama_eval calls, these will be removed + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } + + batch.pos = pos.data(); + } + + if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); + seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); + } + + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); + } + + ggml_backend_sched_reset(lctx.sched); + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); + + ggml_cgraph * gf = llama_build_graph(lctx, batch, false); + + // the output embeddings after the final encoder normalization + struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1]; + + GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); + + ggml_backend_sched_alloc_graph(lctx.sched, gf); + + llama_set_inputs(lctx, batch); + + llama_graph_compute(lctx, gf, n_threads); + + // extract embeddings + if (embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); + GGML_ASSERT(backend_embd != nullptr); + + // extract token embeddings + GGML_ASSERT(lctx.embd != nullptr); + + lctx.embd_enc.resize(n_tokens*n_embd); + float * embd_out = lctx.embd_enc.data(); + + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + + // remember the sequence ids used during the encoding - needed for cross attention later + lctx.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + for (int s = 0; s < batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = batch.seq_id[i][s]; + lctx.seq_ids_enc[i].insert(seq_id); + } + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. 
+ ggml_backend_sched_reset(lctx.sched); + + return 0; +} // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { @@ -14547,11 +15283,14 @@ struct llm_tokenizer_ugm { std::string normalized; normalize(text, &normalized); size_t input_len = normalized.size(); + if (input_len == 0) { + return; + } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {0, 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX}); // at the beginning tokenization score is zero - tokenization_results[0] = { 0, 0, 0 }; + tokenization_results[0] = { vocab.special_unk_id, 0, 0 }; for (size_t input_offset = 0; input_offset < input_len;) { size_t prefix_offset = input_offset; @@ -14571,7 +15310,7 @@ struct llm_tokenizer_ugm { single_codepoint_token_found = true; } llama_token token_id = node->value; - const auto &token_data = vocab.id_to_token[token_id]; + const auto & token_data = vocab.id_to_token[token_id]; // we set the user-defined token scores to 0 to make them more likely to be selected // (normal token scores are log probabilities, so they are negative) @@ -16854,10 +17593,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // sanity checks // - // - qs.n_attention_wv == 0 for Mamba models - // - qs.n_attention_wv == model.hparams.n_layer for Transformer models + // - qs.n_attention_wv == 0 for Mamba models + // - qs.n_attention_wv == model.hparams.n_layer for Transformer models + // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); size_t total_size_org = 0; size_t total_size_new = 0; @@ -16982,6 +17722,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s quantize &= name.find("ssm_x.weight") == std::string::npos; quantize &= name.find("ssm_dt.weight") == std::string::npos; + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + enum ggml_type new_type; void * new_data; size_t new_size; @@ -18138,6 +18881,17 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch return it->second; } +bool llama_model_has_encoder(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5: return true; + default: return false; + } +} + +llama_token llama_model_decoder_start_token(const struct llama_model * model) { + return model->hparams.dec_start_token_id; +} + uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, @@ -19484,6 +20238,17 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } +int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch) { + const int ret = llama_encode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); + } + + return ret; +} + int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) {
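/*
   Usage sketch: how a client is expected to drive the encoder-decoder API introduced in this
   patch (llama_model_has_encoder(), llama_encode(), llama_model_decoder_start_token()),
   pieced together from the examples/main and examples/batched changes above. The function
   name t5_generate, the `inp` vector and `n_predict` are illustrative placeholders; model and
   context setup, detokenization and proper sampling are left out, and token selection is
   reduced to a greedy argmax.

   #include <vector>
   #include "llama.h"

   // Greedy generation with the new encoder-decoder path. `model` and `ctx` are assumed to be
   // created in the usual way and `inp` already holds the tokenized prompt.
   static void t5_generate(llama_model * model, llama_context * ctx,
                           std::vector<llama_token> inp, int n_predict) {
       if (!llama_model_has_encoder(model)) {
           return;
       }

       // run the encoder once over the whole prompt; the encoder output is kept inside the
       // context for later use by the decoder cross-attention layers
       if (llama_encode(ctx, llama_batch_get_one(inp.data(), (int32_t) inp.size(), 0, 0))) {
           return; // failed to encode
       }

       // the decoder starts from the model-specific start token, falling back to BOS
       llama_token tok = llama_model_decoder_start_token(model);
       if (tok == -1) {
           tok = llama_token_bos(model);
       }

       const int n_vocab = llama_n_vocab(model);

       for (int n_past = 0; n_past < n_predict; ++n_past) {
           // decode one token per step; only that token produces logits
           if (llama_decode(ctx, llama_batch_get_one(&tok, 1, n_past, 0))) {
               return; // failed to decode
           }

           // greedy argmax over the logits of the token just decoded
           const float * logits = llama_get_logits(ctx);
           llama_token best = 0;
           for (llama_token t = 1; t < n_vocab; ++t) {
               if (logits[t] > logits[best]) {
                   best = t;
               }
           }

           if (best == llama_token_eos(model)) {
               break;
           }

           tok = best; // detokenize/print here as needed
       }
   }

   Note that llama_encode() stores the encoder output inside the llama_context, so the
   subsequent llama_decode() calls can cross-attend to it without any extra plumbing on the
   caller's side.
*/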