This commit is contained in:
Georgi Gerganov 2024-09-22 17:55:48 +08:00 committed by GitHub
commit c4aba398ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 288 additions and 24 deletions

View File

@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) { [](gpt_params & params) {
params.verbose_prompt = true; params.verbose_prompt = true;
} }
).set_examples({LLAMA_EXAMPLE_MAIN})); ));
add_opt(llama_arg( add_opt(llama_arg(
{"--no-display-prompt"}, {"--no-display-prompt"},
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@ -1098,11 +1098,12 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) { [](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); } else { throw std::invalid_argument("invalid value"); }
} }
).set_examples({LLAMA_EXAMPLE_EMBEDDING})); ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg( add_opt(llama_arg(
{"--attention"}, "{causal,non,causal}", {"--attention"}, "{causal,non,causal}",
"attention type for embeddings, use model default if unspecified", "attention type for embeddings, use model default if unspecified",

View File

@ -291,8 +291,13 @@ class Model:
bid = int(part) bid = int(part)
break break
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
data: np.ndarray # type hint data = data_torch.squeeze().numpy()
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
data = data_torch.numpy()
n_dims = len(data.shape) n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
@ -2598,7 +2603,7 @@ class NomicBertModel(BertModel):
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
@Model.register("XLMRobertaModel") @Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel): class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT model_arch = gguf.MODEL_ARCH.BERT
@ -2696,6 +2701,11 @@ class XLMRobertaModel(BertModel):
self.gguf_writer.add_add_eos_token(True) self.gguf_writer.add_add_eos_token(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "roberta.", remove the prefix
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
if name.startswith("roberta."):
name = name[8:]
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight": if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None: if self._position_offset is not None:

View File

@ -234,6 +234,10 @@ int main(int argc, char ** argv) {
} }
LOG("\n"); LOG("\n");
} }
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else { } else {
// print the first part of the embeddings or for a single prompt, the full embedding // print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) { for (int j = 0; j < n_prompts; j++) {

View File

@ -92,6 +92,7 @@ enum server_task_type {
enum server_task_cmpl_type { enum server_task_cmpl_type {
SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_NORMAL,
SERVER_TASK_CMPL_TYPE_EMBEDDING, SERVER_TASK_CMPL_TYPE_EMBEDDING,
SERVER_TASK_CMPL_TYPE_RERANK,
SERVER_TASK_CMPL_TYPE_INFILL, SERVER_TASK_CMPL_TYPE_INFILL,
}; };
@ -172,6 +173,7 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
bool has_next_token = true; bool has_next_token = true;
bool truncated = false; bool truncated = false;
bool stopped_eos = false; bool stopped_eos = false;
@ -954,8 +956,17 @@ struct server_context {
slot.prompt = *prompt; slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0); slot.prompt = prompt->at(0);
} else if (prompt->is_array() && prompt->size() > 1) {
// array of strings
for (const auto & el : *prompt) {
if (!el.is_string()) {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
slot.prompt = *prompt;
} else { } else {
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
} }
@ -1380,6 +1391,7 @@ struct server_context {
res.data = json { res.data = json {
{"embedding", std::vector<float>(n_embd, 0.0f)}, {"embedding", std::vector<float>(n_embd, 0.0f)},
{"index", slot.index},
}; };
continue; continue;
@ -1398,6 +1410,44 @@ struct server_context {
queue_results.send(res); queue_results.send(res);
} }
void send_rank(const server_slot & slot, const llama_batch & batch) {
server_task_result res;
res.id = slot.id_task;
res.error = false;
res.stop = true;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res.data = json {
{"index", slot.index},
{"rank", -1e6},
};
continue;
}
res.data = json {
{"index", slot.index},
{"rank", embd[0]},
};
}
SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
queue_results.send(res);
}
// //
// Functions to create new task(s) and receive result(s) // Functions to create new task(s) and receive result(s)
// //
@ -1433,13 +1483,23 @@ struct server_context {
// otherwise, it's a multiple-prompt task, we break it into smaller tasks // otherwise, it's a multiple-prompt task, we break it into smaller tasks
else if (prompt.is_array()) { else if (prompt.is_array()) {
std::vector<json> prompts = prompt; std::vector<json> prompts = prompt;
for (size_t i = 0; i < prompts.size(); i++) { if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
const auto & e = prompts[i]; for (size_t i = 1; i < prompts.size(); i++) {
if (e.is_string() || json_is_array_of_numbers(e)) { json qd;
data["index"] = i; qd.push_back(prompts[0]);
create_task(data, true, e); qd.push_back(prompts[i]);
} else { data["index"] = i - 1;
throw std::runtime_error(error_msg); create_task(data, true, qd);
}
} else {
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
data["index"] = i;
create_task(data, true, e);
} else {
throw std::runtime_error(error_msg);
}
} }
} }
} }
@ -1483,7 +1543,9 @@ struct server_context {
break; break;
} }
size_t idx = result.data["index"]; const size_t idx = result.data["index"];
GGML_ASSERT(idx < results.size() && "index out of range");
results[idx] = result; results[idx] = result;
} }
result_handler(results); result_handler(results);
@ -1934,6 +1996,29 @@ struct server_context {
} }
prompt_tokens = embd_inp; prompt_tokens = embd_inp;
} else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}
// prompt: <s>query</s><s>doc</s>
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[1], false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} else { } else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
} }
@ -1953,7 +2038,7 @@ struct server_context {
continue; continue;
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// this prompt is too large to process - discard it // this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) { if (slot.n_prompt_tokens > n_ubatch) {
slot.release(); slot.release();
@ -2023,7 +2108,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0; slot.n_prompt_tokens_processed = 0;
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter // cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue; continue;
@ -2031,7 +2116,10 @@ struct server_context {
} }
// check that we are in the right batch_type, if not defer the slot // check that we are in the right batch_type, if not defer the slot
bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; const bool slot_type =
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
if (batch_type == -1) { if (batch_type == -1) {
batch_type = slot_type; batch_type = slot_type;
} else if (batch_type != slot_type) { } else if (batch_type != slot_type) {
@ -2204,6 +2292,13 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
send_rank(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
// prompt evaluated for next-token prediction // prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING; slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) { } else if (slot.state != SLOT_STATE_GENERATING) {
@ -2994,6 +3089,82 @@ int main(int argc, char ** argv) {
res_ok(res, root); res_ok(res, root);
}; };
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
// TODO: implement
//int top_n = 1;
//if (body.count("top_n") != 1) {
// top_n = body.at("top_n");
//} else {
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
// return;
//}
json query;
if (body.count("query") == 1) {
query = body.at("query");
if (!query.is_string()) {
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
} else {
exit(0);
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
json documents;
if (body.count("documents") != 0) {
documents = body.at("documents");
if (!documents.is_array() || documents.size() == 0) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
}
} else {
res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
// construct prompt object: array of ["query", "doc0", "doc1", ...]
json prompt;
prompt.push_back(query);
for (const auto & doc : documents) {
prompt.push_back(doc);
}
LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
// create and queue the task
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
for (const auto & res : results) {
responses.push_back(res.data);
}
}, [&](const json & error_data) {
res_error(res, error_data);
error = true;
});
}
if (error) {
return;
}
// write JSON response
json root = format_response_rerank(body, responses);
res_ok(res, root);
};
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array(); json result = json::array();
for (size_t i = 0; i < ctx_server.loras.size(); ++i) { for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@ -3090,6 +3261,7 @@ int main(int argc, char ** argv) {
svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings); svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/v1/rerank", handle_rerank);
svr->Post("/tokenize", handle_tokenize); svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize); svr->Post("/detokenize", handle_detokenize);
// LoRA adapters hotswap // LoRA adapters hotswap

View File

@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
json res = json { json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"}, {"object", "list"},
{"usage", json { {"usage", json { // TODO: fill
{"prompt_tokens", 0}, {"prompt_tokens", 0},
{"total_tokens", 0} {"total_tokens", 0}
}}, }},
@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res; return res;
} }
static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array();
int i = 0;
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "rank", 0.0)},
});
}
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json { // TODO: fill
{"prompt_tokens", 0},
{"total_tokens", 0}
}},
{"results", data}
};
return res;
}
static bool is_valid_utf8(const std::string & str) { static bool is_valid_utf8(const std::string & str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data()); const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
const unsigned char* end = bytes + str.length(); const unsigned char* end = bytes + str.length();

View File

@ -342,6 +342,8 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_DOWN = auto() ENC_FFN_DOWN = auto()
ENC_FFN_UP = auto() ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto() ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_OUT = auto() # classifier output projection
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -499,6 +501,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_OUT: "cls.output",
} }
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -608,6 +612,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM, MODEL_TENSOR.LAYER_OUT_NORM,
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
], ],
MODEL_ARCH.NOMIC_BERT: [ MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View File

@ -677,6 +677,14 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: ( MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5 "encoder.final_layer_norm", # t5
), ),
MODEL_TENSOR.CLS: (
"classifier.dense", # roberta
),
MODEL_TENSOR.CLS_OUT: (
"classifier.out_proj", # roberta
),
} }
# architecture-specific block mappings # architecture-specific block mappings

View File

@ -192,6 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3, LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4,
}; };
enum llama_attention_type { enum llama_attention_type {

View File

@ -600,6 +600,8 @@ enum llm_tensor {
LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_DOWN,
LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_FFN_UP,
LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
}; };
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = { static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@ -787,6 +789,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
}, },
}, },
{ {
@ -2861,6 +2865,12 @@ struct llama_model {
struct ggml_tensor * output_b; struct ggml_tensor * output_b;
struct ggml_tensor * output_norm_enc; struct ggml_tensor * output_norm_enc;
// classifier
struct ggml_tensor * cls;
struct ggml_tensor * cls_b;
struct ggml_tensor * cls_out;
struct ggml_tensor * cls_out_b;
std::vector<llama_layer> layers; std::vector<llama_layer> layers;
llama_split_mode split_mode; llama_split_mode split_mode;
@ -7284,6 +7294,12 @@ static bool llm_load_tensors(
if (model.arch == LLM_ARCH_BERT) { if (model.arch == LLM_ARCH_BERT) {
model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
} }
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@ -10111,6 +10127,10 @@ struct llm_build_context {
struct ggml_tensor * cur; struct ggml_tensor * cur;
switch (pooling_type) { switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
{ {
struct ggml_tensor * inp_mean = build_inp_mean(); struct ggml_tensor * inp_mean = build_inp_mean();
@ -10122,9 +10142,24 @@ struct llm_build_context {
struct ggml_tensor * inp_cls = build_inp_cls(); struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls); cur = ggml_get_rows(ctx0, inp, inp_cls);
} break; } break;
case LLAMA_POOLING_TYPE_NONE: case LLAMA_POOLING_TYPE_RANK:
{ {
cur = inp; struct ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
GGML_ASSERT(model.cls != nullptr);
GGML_ASSERT(model.cls_b != nullptr);
GGML_ASSERT(model.cls_out != nullptr);
GGML_ASSERT(model.cls_out_b != nullptr);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
cur = ggml_repeat(ctx0, cur, inp);
} break; } break;
default: default:
{ {
@ -11353,8 +11388,8 @@ struct llm_build_context {
inpL = cur; inpL = cur;
} }
// final output
cur = inpL; cur = inpL;
cb(cur, "result_embd", -1); cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -16331,7 +16366,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
} }
} }
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens;
const int64_t n_seqs = batch.n_seqs; const int64_t n_seqs = batch.n_seqs;
@ -16346,7 +16383,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
const llama_seq_id seq_id = batch.seq_id[s][0]; const llama_seq_id seq_id = batch.seq_id[s][0];
// TODO: adapt limits to n_seqs when batch.equal_seqs is true // TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) { for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = batch.pos[s*n_seq_tokens + i]; const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@ -16873,6 +16910,7 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST: case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{ {
// extract sequence embeddings (cleared before processing each batch) // extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
@ -17076,6 +17114,7 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST: case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{ {
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;