mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
jina : support v1 reranker
This commit is contained in:
parent
c62a39d91e
commit
866c0113fb
@ -597,6 +597,9 @@ class Model:
|
||||
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
||||
# ref: https://huggingface.co/databricks/dbrx-base
|
||||
res = "dbrx"
|
||||
if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
|
||||
# ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
res = "jina-v1-en"
|
||||
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
||||
res = "jina-v2-en"
|
||||
@ -3117,6 +3120,13 @@ class JinaBertV2Model(BertModel):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# if name starts with "bert.", remove the prefix
|
||||
# e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
if name.startswith("bert."):
|
||||
name = name[5:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@Model.register("OpenELMForCausalLM")
|
||||
class OpenELMModel(Model):
|
||||
|
@ -81,6 +81,7 @@ models = [
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
{"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
|
||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
||||
|
@ -647,6 +647,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
],
|
||||
MODEL_ARCH.MPT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
@ -681,6 +681,7 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS: (
|
||||
"classifier", # jina
|
||||
"classifier.dense", # roberta
|
||||
),
|
||||
|
||||
|
@ -828,6 +828,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_CLS, "cls" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -5590,11 +5591,11 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
hparams.f_max_alibi_bias = 8.0f;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small
|
||||
case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small
|
||||
case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base
|
||||
}
|
||||
} break;
|
||||
@ -6287,6 +6288,7 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "phi-2" ||
|
||||
tokenizer_pre == "jina-es" ||
|
||||
tokenizer_pre == "jina-de" ||
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "jina-v2-code") {
|
||||
@ -6408,7 +6410,12 @@ static void llm_load_vocab(
|
||||
|
||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||
std::string word = gguf_get_arr_str(ctx, token_idx, i);
|
||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
|
||||
//GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
if (word.empty()) {
|
||||
LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
|
||||
word = "[EMPTY_" + std::to_string(i) + "]";
|
||||
}
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
|
||||
@ -6487,8 +6494,14 @@ static void llm_load_vocab(
|
||||
vocab.linefeed_id = ids[0];
|
||||
} else {
|
||||
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
|
||||
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
vocab.linefeed_id = ids[0];
|
||||
|
||||
//GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
if (ids.empty()) {
|
||||
LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
|
||||
vocab.linefeed_id = vocab.special_pad_id;
|
||||
} else {
|
||||
vocab.linefeed_id = ids[0];
|
||||
}
|
||||
}
|
||||
|
||||
// special tokens
|
||||
@ -7419,6 +7432,8 @@ static bool llm_load_tensors(
|
||||
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
|
||||
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias
|
||||
|
||||
model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
@ -10237,12 +10252,15 @@ struct llm_build_context {
|
||||
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||
GGML_ASSERT(model.cls != nullptr);
|
||||
GGML_ASSERT(model.cls_b != nullptr);
|
||||
GGML_ASSERT(model.cls_out != nullptr);
|
||||
GGML_ASSERT(model.cls_out_b != nullptr);
|
||||
|
||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
|
||||
cur = ggml_tanh(ctx0, cur);
|
||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
|
||||
|
||||
if (model.cls_out) {
|
||||
GGML_ASSERT(model.cls_out_b != nullptr);
|
||||
|
||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user