mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
Add support for encoder-only T5 models (#8900)
* gguf-py : add T5ENCODER model architecture * common : call llama_decode() during warmup only if the model has decoder * convert-hf : add T5EncoderModel * llama : add llama_model_has_decoder() API function * llama : split build_t5() into build_t5_encoder() and build_t5_decoder() * llama : add support for LLM_ARCH_T5ENCODER * llama-embedding : add support for LLAMA_POOLING_TYPE_NONE * llama-embedding : add support for encoder-only models --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
parent
911b437f22
commit
7c3f55c100
@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
||||
tmp.clear();
|
||||
tmp.push_back(decoder_start_token_id);
|
||||
}
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
}
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
llama_reset_timings(lctx);
|
||||
|
@ -3324,6 +3324,145 @@ class T5Model(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("T5EncoderModel")
|
||||
class T5EncoderModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.T5ENCODER
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.shared_token_embeddings_found = False
|
||||
|
||||
def set_vocab(self):
|
||||
# to avoid TypeError: Descriptors cannot be created directly
|
||||
# exception when importing sentencepiece_model_pb2
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
tokenizer_path = self.dir_model / 'tokenizer.model'
|
||||
|
||||
# many older models use spiece.model tokenizer model filename
|
||||
if not tokenizer_path.is_file():
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
|
||||
# assure the tokenizer model file name is correct
|
||||
assert tokenizer_path.name == 'tokenizer.model'
|
||||
return self._set_vocab_sentencepiece()
|
||||
else:
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
|
||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
piece = tokenizer.IdToPiece(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.GetScore(token_id)
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.IsUnknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.IsControl(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.IsUnused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.IsByte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens[token_id] = text
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
|
||||
added_tokens_file = self.dir_model / 'added_tokens.json'
|
||||
if added_tokens_file.is_file():
|
||||
with open(added_tokens_file, "r", encoding="utf-8") as f:
|
||||
added_tokens_json = json.load(f)
|
||||
for key in added_tokens_json:
|
||||
token_id = added_tokens_json[key]
|
||||
if token_id >= vocab_size:
|
||||
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
|
||||
continue
|
||||
|
||||
tokens[token_id] = key.encode("utf-8")
|
||||
scores[token_id] = -1000.0
|
||||
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
|
||||
|
||||
if vocab_size > len(tokens):
|
||||
pad_count = vocab_size - len(tokens)
|
||||
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
|
||||
for i in range(1, pad_count + 1):
|
||||
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
|
||||
scores.append(-1000.0)
|
||||
toktypes.append(SentencePieceTokenTypes.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("t5")
|
||||
self.gguf_writer.add_tokenizer_pre("default")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
self.gguf_writer.add_add_space_prefix(add_prefix)
|
||||
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
|
||||
if precompiled_charsmap:
|
||||
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
|
||||
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
|
||||
n_ctx = 512
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
|
||||
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
|
||||
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
|
||||
# and decoder and ignore the remaining ones.
|
||||
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
|
||||
if not self.shared_token_embeddings_found:
|
||||
name = "shared.weight"
|
||||
self.shared_token_embeddings_found = True
|
||||
else:
|
||||
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("JAISLMHeadModel")
|
||||
class JaisModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.JAIS
|
||||
|
@ -31,13 +31,24 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
const struct llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
|
||||
// encoder-only model
|
||||
if (llama_encode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to encode\n", __func__);
|
||||
}
|
||||
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||
// decoder-only model
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
@ -45,11 +56,22 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
continue;
|
||||
}
|
||||
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
const float * embd = nullptr;
|
||||
int embd_pos = 0;
|
||||
|
||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
// try to get token embeddings
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
embd_pos = i;
|
||||
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
|
||||
} else {
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
embd_pos = batch.seq_id[i][0];
|
||||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
}
|
||||
|
||||
float * out = output + embd_pos * n_embd;
|
||||
llama_embd_normalize(embd, out, n_embd, embd_norm);
|
||||
}
|
||||
}
|
||||
@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
|
||||
|
||||
if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||
fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
// count number of embeddings
|
||||
int n_embd_count = 0;
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
n_embd_count += inputs[k].size();
|
||||
}
|
||||
} else {
|
||||
n_embd_count = n_prompts;
|
||||
}
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
std::vector<float> embeddings(n_prompts * n_embd, 0);
|
||||
std::vector<float> embeddings(n_embd_count * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
// break into batches
|
||||
int p = 0; // number of prompts processed already
|
||||
int e = 0; // number of embeddings already stored
|
||||
int s = 0; // number of prompts in current batch
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
// clamp to n_batch tokens
|
||||
@ -169,11 +202,11 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
float * out = emb + e * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
||||
llama_batch_clear(batch);
|
||||
p += s;
|
||||
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
|
||||
s = 0;
|
||||
llama_batch_clear(batch);
|
||||
}
|
||||
|
||||
// add to batch
|
||||
@ -182,40 +215,63 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// final batch
|
||||
float * out = emb + p * n_embd;
|
||||
float * out = emb + e * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
||||
|
||||
if (params.embd_out.empty()) {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
fprintf(stdout, "\n");
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
fprintf(stdout, "embedding %d: ", j);
|
||||
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
||||
} else {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
// print cosine similarity matrix
|
||||
if (n_prompts > 1) {
|
||||
fprintf(stdout, "\n");
|
||||
printf("cosine similarity matrix:\n\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
fprintf(stdout, "%6.6s ", prompts[i].c_str());
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f ", sim);
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
fprintf(stdout, "embedding %d: ", j);
|
||||
for (int i = 0; i < std::min(3, n_embd); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
||||
} else {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
fprintf(stdout, " ... ");
|
||||
for (int i = n_embd - 3; i < n_embd; i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
||||
} else {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "%1.10s", prompts[i].c_str());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
} else {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
fprintf(stdout, "embedding %d: ", j);
|
||||
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
||||
} else {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
// print cosine similarity matrix
|
||||
if (n_prompts > 1) {
|
||||
fprintf(stdout, "\n");
|
||||
printf("cosine similarity matrix:\n\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
fprintf(stdout, "%6.6s ", prompts[i].c_str());
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f ", sim);
|
||||
}
|
||||
fprintf(stdout, "%1.10s", prompts[i].c_str());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
fprintf(stdout, notArray ? "]\n }" : "]");
|
||||
j++;
|
||||
if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
|
||||
if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
|
||||
}
|
||||
fprintf(stdout, notArray ? "\n ]" : "]\n");
|
||||
|
||||
if (params.embd_out == "json+" && n_prompts > 1) {
|
||||
fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
|
||||
for (int i = 0;;) { // at least two iteration (n_prompts > 1)
|
||||
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
fprintf(stdout, " [");
|
||||
for (int j = 0;;) { // at least two iteration (n_prompts > 1)
|
||||
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f", sim);
|
||||
j++;
|
||||
if (j < n_prompts) fprintf(stdout, ", "); else break;
|
||||
if (j < n_embd_count) fprintf(stdout, ", "); else break;
|
||||
}
|
||||
fprintf(stdout, " ]");
|
||||
i++;
|
||||
if (i < n_prompts) fprintf(stdout, ",\n"); else break;
|
||||
if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
|
||||
}
|
||||
fprintf(stdout, "\n ]");
|
||||
}
|
||||
|
@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
|
||||
CHATGLM = auto()
|
||||
BITNET = auto()
|
||||
T5 = auto()
|
||||
T5ENCODER = auto()
|
||||
JAIS = auto()
|
||||
|
||||
|
||||
@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
MODEL_ARCH.T5: "t5",
|
||||
MODEL_ARCH.T5ENCODER: "t5encoder",
|
||||
MODEL_ARCH.JAIS: "jais",
|
||||
}
|
||||
|
||||
@ -1036,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ENC_FFN_UP,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.T5ENCODER: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ENC_ATTN_NORM,
|
||||
MODEL_TENSOR.ENC_ATTN_Q,
|
||||
MODEL_TENSOR.ENC_ATTN_K,
|
||||
MODEL_TENSOR.ENC_ATTN_V,
|
||||
MODEL_TENSOR.ENC_ATTN_OUT,
|
||||
MODEL_TENSOR.ENC_ATTN_REL_B,
|
||||
MODEL_TENSOR.ENC_FFN_NORM,
|
||||
MODEL_TENSOR.ENC_FFN_GATE,
|
||||
MODEL_TENSOR.ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.ENC_FFN_UP,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.JAIS: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -504,6 +504,9 @@ extern "C" {
|
||||
// Returns true if the model contains an encoder that requires llama_encode() call
|
||||
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
|
||||
|
||||
// Returns true if the model contains a decoder that requires llama_decode() call
|
||||
LLAMA_API bool llama_model_has_decoder(const struct llama_model * model);
|
||||
|
||||
// For encoder-decoder models, this function returns id of the token that must be provided
|
||||
// to the decoder to start generating output sequence. For other models, it returns -1.
|
||||
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
||||
|
730
src/llama.cpp
730
src/llama.cpp
@ -208,6 +208,7 @@ enum llm_arch {
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_T5ENCODER,
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_T5, "t5" },
|
||||
{ LLM_ARCH_T5ENCODER, "t5encoder" },
|
||||
{ LLM_ARCH_JAIS, "jais" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
@ -1261,6 +1263,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_T5ENCODER,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
||||
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
|
||||
{ LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
|
||||
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JAIS,
|
||||
{
|
||||
@ -5187,6 +5207,12 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
|
||||
model.type = e_model::MODEL_UNKNOWN;
|
||||
} break;
|
||||
case LLM_ARCH_JAIS:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@ -7421,6 +7447,42 @@ static bool llm_load_tensors(
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
{
|
||||
const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
|
||||
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
|
||||
layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||
layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||
layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
|
||||
|
||||
layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_JAIS:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
@ -13135,7 +13197,7 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_t5() {
|
||||
struct ggml_cgraph * build_t5_encoder() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||
@ -13150,303 +13212,323 @@ struct llm_build_context {
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
if (lctx.is_encoding) {
|
||||
struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
|
||||
GGML_ASSERT(lctx.is_encoding);
|
||||
struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
|
||||
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
|
||||
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
||||
cb(kq_b, "kq_b", il);
|
||||
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
|
||||
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
|
||||
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
||||
cb(kq_b, "kq_b", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
||||
cb(v, "v", il);
|
||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
||||
cb(kqv, "kqv", il);
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// T5 uses relu, flan-T5 uses gelu-gated
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up_enc, NULL, NULL,
|
||||
model.layers[il].ffn_gate_enc, NULL, NULL,
|
||||
model.layers[il].ffn_down_enc, NULL, NULL,
|
||||
NULL,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
||||
cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx0, cur, layer_dir);
|
||||
}
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
} else {
|
||||
GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
|
||||
|
||||
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
|
||||
struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
|
||||
|
||||
struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
|
||||
struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
|
||||
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
|
||||
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
||||
cb(kq_b, "kq_b", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "cross_inp", il);
|
||||
|
||||
struct ggml_tensor * inpCA = cur;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_norm_cross, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm_cross", il);
|
||||
|
||||
// cross-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// T5 uses relu, flan-T5 uses gelu-gated
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
||||
cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx0, cur, layer_dir);
|
||||
}
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cb(cur, "result_embd", -1);
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
// T5 uses relu, flan-T5 uses gelu-gated
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up_enc, NULL, NULL,
|
||||
model.layers[il].ffn_gate_enc, NULL, NULL,
|
||||
model.layers[il].ffn_down_enc, NULL, NULL,
|
||||
NULL,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
||||
cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx0, cur, layer_dir);
|
||||
}
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm_enc, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_t5_decoder() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||
int32_t n_tokens = this->n_tokens;
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
GGML_ASSERT(!lctx.is_encoding);
|
||||
GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
|
||||
|
||||
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
|
||||
struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
|
||||
|
||||
struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
|
||||
struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
|
||||
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
|
||||
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
||||
cb(kq_b, "kq_b", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "cross_inp", il);
|
||||
|
||||
struct ggml_tensor * inpCA = cur;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_norm_cross, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm_cross", il);
|
||||
|
||||
// cross-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// T5 uses relu, flan-T5 uses gelu-gated
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
|
||||
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
||||
cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx0, cur, layer_dir);
|
||||
}
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
@ -13898,7 +13980,15 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
} break;
|
||||
case LLM_ARCH_T5:
|
||||
{
|
||||
result = llm.build_t5();
|
||||
if (lctx.is_encoding) {
|
||||
result = llm.build_t5_encoder();
|
||||
} else {
|
||||
result = llm.build_t5_decoder();
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
{
|
||||
result = llm.build_t5_encoder();
|
||||
} break;
|
||||
case LLM_ARCH_JAIS:
|
||||
{
|
||||
@ -14346,7 +14436,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
const bool has_logits = !cparams.embeddings;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
|
||||
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
||||
|
||||
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
||||
@ -14829,9 +14919,24 @@ static int llama_encode_internal(
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
|
||||
|
||||
// the output embeddings after the final encoder normalization
|
||||
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = nullptr;
|
||||
|
||||
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
|
||||
// there are two cases here
|
||||
if (llama_model_has_decoder(&lctx.model)) {
|
||||
// first case is an encoder-decoder T5 model where embeddings are passed to decoder
|
||||
embd = gf->nodes[gf->n_nodes - 1];
|
||||
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
|
||||
} else {
|
||||
// second case is an encoder-only T5 model
|
||||
if (cparams.embeddings) {
|
||||
// only output embeddings if required
|
||||
embd = gf->nodes[gf->n_nodes - 1];
|
||||
if (strcmp(embd->name, "result_embd_pooled") != 0) {
|
||||
embd = gf->nodes[gf->n_nodes - 2];
|
||||
}
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
|
||||
@ -14844,20 +14949,54 @@ static int llama_encode_internal(
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
if (llama_model_has_decoder(&lctx.model)) {
|
||||
lctx.embd_enc.resize(n_tokens*n_embd);
|
||||
float * embd_out = lctx.embd_enc.data();
|
||||
|
||||
lctx.embd_enc.resize(n_tokens*n_embd);
|
||||
float * embd_out = lctx.embd_enc.data();
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
lctx.seq_ids_enc.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
lctx.seq_ids_enc[i].insert(seq_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
lctx.seq_ids_enc.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
lctx.seq_ids_enc[i].insert(seq_id);
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
float * embd_out = lctx.embd;
|
||||
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -16567,6 +16706,8 @@ struct llama_context * llama_new_context_with_model(
|
||||
|
||||
ctx->sampling.rng = std::mt19937(params.seed);
|
||||
ctx->logits_all = params.logits_all;
|
||||
// build worst-case graph for encoder if a model contains encoder
|
||||
ctx->is_encoding = llama_model_has_encoder(model);
|
||||
|
||||
uint32_t kv_size = cparams.n_ctx;
|
||||
ggml_type type_k = params.type_k;
|
||||
@ -16881,6 +17022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
case LLM_ARCH_T5:
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
case LLM_ARCH_JAIS:
|
||||
return LLAMA_ROPE_TYPE_NONE;
|
||||
|
||||
@ -17028,8 +17170,16 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch
|
||||
|
||||
bool llama_model_has_encoder(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
case LLM_ARCH_T5: return true;
|
||||
default: return false;
|
||||
case LLM_ARCH_T5: return true;
|
||||
case LLM_ARCH_T5ENCODER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_model_has_decoder(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
case LLM_ARCH_T5ENCODER: return false;
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user