diff --git a/common/arg.cpp b/common/arg.cpp index 922391069..23de17b64 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.verbose_prompt = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + )); add_opt(llama_arg( {"--no-display-prompt"}, format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), @@ -1098,11 +1098,12 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER})); add_opt(llama_arg( {"--attention"}, "{causal,non,causal}", "attention type for embeddings, use model default if unspecified", diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ff4c9226f..9c78a2dc1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -291,8 +291,13 @@ class Model: bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + data = data_torch.squeeze().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) @@ -2598,7 +2603,7 @@ class NomicBertModel(BertModel): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2696,6 +2701,11 @@ class XLMRobertaModel(BertModel): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a438dcb5a..a0ca9d98c 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -234,6 +234,10 @@ int main(int argc, char ** argv) { } LOG("\n"); } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]); + } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0ca999994..8ca228176 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,6 +92,7 @@ enum server_task_type { enum server_task_cmpl_type { SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_CMPL_TYPE_INFILL, }; @@ -172,6 +173,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -954,8 +956,17 @@ struct server_context { slot.prompt = *prompt; } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { slot.prompt = prompt->at(0); + } else if (prompt->is_array() && prompt->size() > 1) { + // array of strings + for (const auto & el : *prompt) { + if (!el.is_string()) { + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + slot.prompt = *prompt; } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1380,6 +1391,7 @@ struct server_context { res.data = json { {"embedding", std::vector(n_embd, 0.0f)}, + {"index", slot.index}, }; continue; @@ -1398,6 +1410,44 @@ struct server_context { queue_results.send(res); } + void send_rank(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.error = false; + res.stop = true; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res.data = json { + {"index", slot.index}, + {"rank", -1e6}, + }; + + continue; + } + + res.data = json { + {"index", slot.index}, + {"rank", embd[0]}, + }; + } + + SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str()); + + queue_results.send(res); + } + // // Functions to create new task(s) and receive result(s) // @@ -1433,13 +1483,23 @@ struct server_context { // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + for (size_t i = 1; i < prompts.size(); i++) { + json qd; + qd.push_back(prompts[0]); + qd.push_back(prompts[i]); + data["index"] = i - 1; + create_task(data, true, qd); + } + } else { + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } else { + throw std::runtime_error(error_msg); + } } } } @@ -1483,7 +1543,9 @@ struct server_context { break; } - size_t idx = result.data["index"]; + const size_t idx = result.data["index"]; + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = result; } result_handler(results); @@ -1934,6 +1996,29 @@ struct server_context { } prompt_tokens = embd_inp; + } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // require slot.prompt to be array of 2 strings + if (!slot.prompt.is_array() || slot.prompt.size() != 2) { + SLT_ERR(slot, "%s", "invalid prompt for rerank task\n"); + slot.release(); + send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST); + continue; + } + + // prompt: querydoc + prompt_tokens.clear(); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[0], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[1], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } @@ -1953,7 +2038,7 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { slot.release(); @@ -2023,7 +2108,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2031,7 +2116,10 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; + const bool slot_type = + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; + if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2204,6 +2292,13 @@ struct server_context { continue; // continue loop of slots } + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + send_rank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { @@ -2994,6 +3089,82 @@ int main(int argc, char ** argv) { res_ok(res, root); }; + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + exit(0); + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + json documents; + if (body.count("documents") != 0) { + documents = body.at("documents"); + if (!documents.is_array() || documents.size() == 0) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // construct prompt object: array of ["query", "doc0", "doc1", ...] + json prompt; + prompt.push_back(query); + for (const auto & doc : documents) { + prompt.push_back(doc); + } + + LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank(body, responses); + res_ok(res, root); + }; + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.loras.size(); ++i) { @@ -3090,6 +3261,7 @@ int main(int argc, char ** argv) { svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/v1/rerank", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); // LoRA adapters hotswap diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index f093f547f..91e7f792d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, - {"usage", json { + {"usage", json { // TODO: fill {"prompt_tokens", 0}, {"total_tokens", 0} }}, @@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "rank", 0.0)}, + }); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { // TODO: fill + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"results", data} + }; + + return res; +} + static bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b36a60d49..d54427479 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -342,6 +342,8 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -499,6 +501,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -608,6 +612,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2ebfa2b43..608d20ea8 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -677,6 +677,14 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index f316a87ba..a54a70077 100644 --- a/include/llama.h +++ b/include/llama.h @@ -192,6 +192,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, }; enum llama_attention_type { diff --git a/src/llama.cpp b/src/llama.cpp index bc4e408e0..c82a59244 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -600,6 +600,8 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -787,6 +789,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -2861,6 +2865,12 @@ struct llama_model { struct ggml_tensor * output_b; struct ggml_tensor * output_norm_enc; + // classifier + struct ggml_tensor * cls; + struct ggml_tensor * cls_b; + struct ggml_tensor * cls_out; + struct ggml_tensor * cls_out_b; + std::vector layers; llama_split_mode split_mode; @@ -7284,6 +7294,12 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_BERT) { model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); @@ -10111,6 +10127,10 @@ struct llm_build_context { struct ggml_tensor * cur; switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; case LLAMA_POOLING_TYPE_MEAN: { struct ggml_tensor * inp_mean = build_inp_mean(); @@ -10122,9 +10142,24 @@ struct llm_build_context { struct ggml_tensor * inp_cls = build_inp_cls(); cur = ggml_get_rows(ctx0, inp, inp_cls); } break; - case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_RANK: { - cur = inp; + struct ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + GGML_ASSERT(model.cls != nullptr); + GGML_ASSERT(model.cls_b != nullptr); + GGML_ASSERT(model.cls_out != nullptr); + GGML_ASSERT(model.cls_out_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); + cur = ggml_tanh(ctx0, cur); + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + + // broadcast across the embedding size to make it compatible with the llama_get_embeddings API + cur = ggml_repeat(ctx0, cur, inp); } break; default: { @@ -11353,8 +11388,8 @@ struct llm_build_context { inpL = cur; } - // final output cur = inpL; + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -16331,7 +16366,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { const int64_t n_tokens = batch.n_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seqs = batch.n_seqs; @@ -16346,7 +16383,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const llama_seq_id seq_id = batch.seq_id[s][0]; // TODO: adapt limits to n_seqs when batch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); for (int i = 0; i < n_seq_tokens; ++i) { const llama_pos pos = batch.pos[s*n_seq_tokens + i]; @@ -16873,6 +16910,7 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; @@ -17076,6 +17114,7 @@ static int llama_encode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq;