diff --git a/common/common.cpp b/common/common.cpp index 560e20d08..d3d896115 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { tmp.clear(); tmp.push_back(decoder_start_token_id); } - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + if (llama_model_has_decoder(model)) { + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + } llama_kv_cache_clear(lctx); llama_synchronize(lctx); llama_reset_timings(lctx); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7136db440..550dd5cfd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3324,6 +3324,145 @@ class T5Model(Model): return [(self.map_tensor_name(name), data_torch)] +@Model.register("T5EncoderModel") +class T5EncoderModel(Model): + model_arch = gguf.MODEL_ARCH.T5ENCODER + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_token_embeddings_found = False + + def set_vocab(self): + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + # many older models use spiece.model tokenizer model filename + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'spiece.model' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + # some models like Pile-T5 family use BPE tokenizer instead of Unigram + if sentencepiece_model.trainer_spec.model_type == 2: # BPE + # assure the tokenizer model file name is correct + assert tokenizer_path.name == 'tokenizer.model' + return self._set_vocab_sentencepiece() + else: + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if token_id >= vocab_size: + logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_add_eos_token(True) + + def set_gguf_parameters(self): + if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: + logger.warning("Couldn't find context length in config.json, assuming default value of 512") + n_ctx = 512 + self.gguf_writer.add_context_length(n_ctx) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(self.hparams["num_heads"]) + self.gguf_writer.add_key_length(self.hparams["d_kv"]) + self.gguf_writer.add_value_length(self.hparams["d_kv"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", + # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored + # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder + # and decoder and ignore the remaining ones. + if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if not self.shared_token_embeddings_found: + name = "shared.weight" + self.shared_token_embeddings_found = True + else: + logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + return [] + + return [(self.map_tensor_name(name), data_torch)] + + @Model.register("JAISLMHeadModel") class JaisModel(Model): model_arch = gguf.MODEL_ARCH.JAIS diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index cd7b448a6..b05aa006e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -31,13 +31,24 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + const struct llama_model * model = llama_get_model(ctx); + // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - fprintf(stderr, "%s : failed to decode\n", __func__); + if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { + // encoder-only model + if (llama_encode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to encode\n", __func__); + } + } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + // decoder-only model + if (llama_decode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to decode\n", __func__); + } } for (int i = 0; i < batch.n_tokens; i++) { @@ -45,11 +56,22 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu continue; } - // try to get sequence embeddings - supported only when pooling_type is not NONE - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + const float * embd = nullptr; + int embd_pos = 0; - float * out = output + batch.seq_id[i][0] * n_embd; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings + embd = llama_get_embeddings_ith(ctx, i); + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd_pos = batch.seq_id[i][0]; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + } + + float * out = output + embd_pos * n_embd; llama_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -93,8 +115,9 @@ int main(int argc, char ** argv) { const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + + if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__); return 1; } @@ -153,13 +176,23 @@ int main(int argc, char ** argv) { const int n_prompts = prompts.size(); struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + // count number of embeddings + int n_embd_count = 0; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int k = 0; k < n_prompts; k++) { + n_embd_count += inputs[k].size(); + } + } else { + n_embd_count = n_prompts; + } + // allocate output const int n_embd = llama_n_embd(model); - std::vector embeddings(n_prompts * n_embd, 0); + std::vector embeddings(n_embd_count * n_embd, 0); float * emb = embeddings.data(); // break into batches - int p = 0; // number of prompts processed already + int e = 0; // number of embeddings already stored int s = 0; // number of prompts in current batch for (int k = 0; k < n_prompts; k++) { // clamp to n_batch tokens @@ -169,11 +202,11 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { - float * out = emb + p * n_embd; + float * out = emb + e * n_embd; batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - llama_batch_clear(batch); - p += s; + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; s = 0; + llama_batch_clear(batch); } // add to batch @@ -182,40 +215,63 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + p * n_embd; + float * out = emb + e * n_embd; batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); if (params.embd_out.empty()) { - // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); - for (int j = 0; j < n_prompts; j++) { - fprintf(stdout, "embedding %d: ", j); - for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { - if (params.embd_normalize == 0) { - fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); - } else { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); - } - } - fprintf(stdout, "\n"); - } - // print cosine similarity matrix - if (n_prompts > 1) { - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); - for (int i = 0; i < n_prompts; i++) { - fprintf(stdout, "%6.6s ", prompts[i].c_str()); - } - fprintf(stdout, "\n"); - for (int i = 0; i < n_prompts; i++) { - for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int j = 0; j < n_embd_count; j++) { + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < std::min(3, n_embd); i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } + } + fprintf(stdout, " ... "); + for (int i = n_embd - 3; i < n_embd; i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } } - fprintf(stdout, "%1.10s", prompts[i].c_str()); fprintf(stdout, "\n"); } + } else { + // print the first part of the embeddings or for a single prompt, the full embedding + for (int j = 0; j < n_prompts; j++) { + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } + } + fprintf(stdout, "\n"); + } + + // print cosine similarity matrix + if (n_prompts > 1) { + fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + fprintf(stdout, "%6.6s ", prompts[i].c_str()); + } + fprintf(stdout, "\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "%1.10s", prompts[i].c_str()); + fprintf(stdout, "\n"); + } + } } } @@ -233,23 +289,23 @@ int main(int argc, char ** argv) { } fprintf(stdout, notArray ? "]\n }" : "]"); j++; - if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break; + if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break; } fprintf(stdout, notArray ? "\n ]" : "]\n"); if (params.embd_out == "json+" && n_prompts > 1) { fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); - for (int i = 0;;) { // at least two iteration (n_prompts > 1) + for (int i = 0;;) { // at least two iteration (n_embd_count > 1) fprintf(stdout, " ["); - for (int j = 0;;) { // at least two iteration (n_prompts > 1) + for (int j = 0;;) { // at least two iteration (n_embd_count > 1) float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); fprintf(stdout, "%6.2f", sim); j++; - if (j < n_prompts) fprintf(stdout, ", "); else break; + if (j < n_embd_count) fprintf(stdout, ", "); else break; } fprintf(stdout, " ]"); i++; - if (i < n_prompts) fprintf(stdout, ",\n"); else break; + if (i < n_embd_count) fprintf(stdout, ",\n"); else break; } fprintf(stdout, "\n ]"); } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 89efe0c80..f63ec450a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum): CHATGLM = auto() BITNET = auto() T5 = auto() + T5ENCODER = auto() JAIS = auto() @@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", } @@ -1036,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ENC_FFN_UP, MODEL_TENSOR.ENC_OUTPUT_NORM, ], + MODEL_ARCH.T5ENCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ENC_ATTN_NORM, + MODEL_TENSOR.ENC_ATTN_Q, + MODEL_TENSOR.ENC_ATTN_K, + MODEL_TENSOR.ENC_ATTN_V, + MODEL_TENSOR.ENC_ATTN_OUT, + MODEL_TENSOR.ENC_ATTN_REL_B, + MODEL_TENSOR.ENC_FFN_NORM, + MODEL_TENSOR.ENC_FFN_GATE, + MODEL_TENSOR.ENC_FFN_DOWN, + MODEL_TENSOR.ENC_FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + ], MODEL_ARCH.JAIS: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/include/llama.h b/include/llama.h index 66c266298..ce07f4fac 100644 --- a/include/llama.h +++ b/include/llama.h @@ -504,6 +504,9 @@ extern "C" { // Returns true if the model contains an encoder that requires llama_encode() call LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); + // Returns true if the model contains a decoder that requires llama_decode() call + LLAMA_API bool llama_model_has_decoder(const struct llama_model * model); + // For encoder-decoder models, this function returns id of the token that must be provided // to the decoder to start generating output sequence. For other models, it returns -1. LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); diff --git a/src/llama.cpp b/src/llama.cpp index 97dd1b3fe..9c4f2aa72 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -208,6 +208,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_BITNET, LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_UNKNOWN, }; @@ -252,6 +253,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1261,6 +1263,24 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_T5ENCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_JAIS, { @@ -5187,6 +5207,12 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_T5ENCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + model.type = e_model::MODEL_UNKNOWN; + } break; case LLM_ARCH_JAIS: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -7421,6 +7447,42 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_T5ENCODER: + { + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; case LLM_ARCH_JAIS: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -13135,7 +13197,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_t5() { + struct ggml_cgraph * build_t5_encoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens @@ -13150,303 +13212,323 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - if (lctx.is_encoding) { - struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false); + GGML_ASSERT(lctx.is_encoding); + struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false); + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; - // norm - cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm_enc, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); - // self-attention - { - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); + // self-attention + { + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); - struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b); - struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); - cb(kq_b, "kq_b", il); + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); - kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); - struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); - cb(v, "v", il); + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); + cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); - ggml_build_forward_expand(gf, cur); + ggml_build_forward_expand(gf, cur); - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur); - cb(cur, "kqv_out", il); - } - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm_enc, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, lctx, cur, - model.layers[il].ffn_up_enc, NULL, NULL, - model.layers[il].ffn_gate_enc, NULL, NULL, - model.layers[il].ffn_down_enc, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur); + cb(cur, "kqv_out", il); } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm_enc, NULL, - LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); - } else { - GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); - - struct ggml_tensor * embd_enc = llm_build_inp_embd_enc(); - struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true); - - struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask(); - struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - // norm - cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - struct ggml_tensor * k = - ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - 0); - cb(k, "k", il); - - struct ggml_tensor * v = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self.v_l[il])*n_ctx, - ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - - struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); - - struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; - struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b); - struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); - cb(kq_b, "kq_b", il); - - kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - cb(cur, "kqv_merged_cont", il); - - ggml_build_forward_expand(gf, cur); - - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); - cb(cur, "kqv_out", il); - } - - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "cross_inp", il); - - struct ggml_tensor * inpCA = cur; - - // norm - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].attn_norm_cross, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm_cross", il); - - // cross-attention - { - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); - - struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - cb(cur, "kqv_merged_cont", il); - - ggml_build_forward_expand(gf, cur); - - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur); - cb(cur, "kqv_out", il); - } - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, lctx, cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - cur = inpL; - cb(cur, "result_embd", -1); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, NULL, - LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - // lm_head - cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); - cb(cur, "result_output", -1); + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm_enc, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_t5_decoder() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + GGML_ASSERT(!lctx.is_encoding); + GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); + + struct ggml_tensor * embd_enc = llm_build_inp_embd_enc(); + struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true); + + struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); + + kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + struct ggml_tensor * inpCA = cur; + + // norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm_cross", il); + + // cross-attention + { + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); return gf; @@ -13898,7 +13980,15 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_T5: { - result = llm.build_t5(); + if (lctx.is_encoding) { + result = llm.build_t5_encoder(); + } else { + result = llm.build_t5_decoder(); + } + } break; + case LLM_ARCH_T5ENCODER: + { + result = llm.build_t5_encoder(); } break; case LLM_ARCH_JAIS: { @@ -14346,7 +14436,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { // TODO: use a per-batch flag for logits presence instead const bool has_logits = !cparams.embeddings; - const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; @@ -14829,9 +14919,24 @@ static int llama_encode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch, false); // the output embeddings after the final encoder normalization - struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embd = nullptr; - GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); + // there are two cases here + if (llama_model_has_decoder(&lctx.model)) { + // first case is an encoder-decoder T5 model where embeddings are passed to decoder + embd = gf->nodes[gf->n_nodes - 1]; + GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor"); + } else { + // second case is an encoder-only T5 model + if (cparams.embeddings) { + // only output embeddings if required + embd = gf->nodes[gf->n_nodes - 1]; + if (strcmp(embd->name, "result_embd_pooled") != 0) { + embd = gf->nodes[gf->n_nodes - 2]; + } + GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); + } + } ggml_backend_sched_alloc_graph(lctx.sched, gf); @@ -14844,20 +14949,54 @@ static int llama_encode_internal( ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); - // extract token embeddings - GGML_ASSERT(lctx.embd != nullptr); + if (llama_model_has_decoder(&lctx.model)) { + lctx.embd_enc.resize(n_tokens*n_embd); + float * embd_out = lctx.embd_enc.data(); - lctx.embd_enc.resize(n_tokens*n_embd); - float * embd_out = lctx.embd_enc.data(); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); - ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + // remember the sequence ids used during the encoding - needed for cross attention later + lctx.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + for (int s = 0; s < batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = batch.seq_id[i][s]; + lctx.seq_ids_enc[i].insert(seq_id); + } + } + } else { + GGML_ASSERT(lctx.embd != nullptr); - // remember the sequence ids used during the encoding - needed for cross attention later - lctx.seq_ids_enc.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - for (int s = 0; s < batch.n_seq_id[i]; s++) { - llama_seq_id seq_id = batch.seq_id[i][s]; - lctx.seq_ids_enc[i].insert(seq_id); + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(lctx.embd != nullptr); + float * embd_out = lctx.embd; + + GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = lctx.embd_seq; + embd_seq_out.clear(); + + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } } } } @@ -16567,6 +16706,8 @@ struct llama_context * llama_new_context_with_model( ctx->sampling.rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; + // build worst-case graph for encoder if a model contains encoder + ctx->is_encoding = llama_model_has_encoder(model); uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -16881,6 +17022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MAMBA: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: case LLM_ARCH_JAIS: return LLAMA_ROPE_TYPE_NONE; @@ -17028,8 +17170,16 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch bool llama_model_has_encoder(const struct llama_model * model) { switch (model->arch) { - case LLM_ARCH_T5: return true; - default: return false; + case LLM_ARCH_T5: return true; + case LLM_ARCH_T5ENCODER: return true; + default: return false; + } +} + +bool llama_model_has_decoder(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5ENCODER: return false; + default: return true; } }