Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-31 22:04:35 +00:00
llama : add "rank" pooling type
ggml-ci
commit 125a0671ab (parent d0a7bf9382)
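This commit adds a new pooling type, LLAMA_POOLING_TYPE_RANK, intended for reranking models: the graph pools the selected (CLS) token, runs the model's classification head (model.cls / model.cls_out) over it, and broadcasts the resulting score across the embedding size so it can be read back through the existing embeddings API. A minimal sketch of how a caller might request the new pooling type — not part of this commit, and the helper name is illustrative:

    #include "llama.h"

    // Hypothetical helper: create a context that uses the new rank pooling type.
    static struct llama_context * make_rank_ctx(struct llama_model * model) {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;                    // pooled embeddings must be enabled
        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; // pooling type added by this commit
        return llama_new_context_with_model(model, cparams);
    }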
@@ -1098,8 +1098,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & value) {
             /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
             else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
             else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
             else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+            else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
@@ -234,6 +234,10 @@ int main(int argc, char ** argv) {
                 }
                 LOG("\n");
             }
+        } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            for (int j = 0; j < n_embd_count; j++) {
+                LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
+            }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
             for (int j = 0; j < n_prompts; j++) {
@@ -192,6 +192,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
         LLAMA_POOLING_TYPE_LAST = 3,
+        LLAMA_POOLING_TYPE_RANK = 4,
     };
 
     enum llama_attention_type {
@@ -10213,6 +10213,10 @@ struct llm_build_context {
         struct ggml_tensor * cur;
 
         switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
             case LLAMA_POOLING_TYPE_MEAN:
                 {
                     struct ggml_tensor * inp_mean = build_inp_mean();
@@ -10224,9 +10228,24 @@ struct llm_build_context {
                     struct ggml_tensor * inp_cls = build_inp_cls();
                     cur = ggml_get_rows(ctx0, inp, inp_cls);
                 } break;
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    cur = inp;
-                } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    cur = inp;
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    inp = ggml_get_rows(ctx0, inp, inp_cls);
+
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    GGML_ASSERT(model.cls       != nullptr);
+                    GGML_ASSERT(model.cls_b     != nullptr);
+                    GGML_ASSERT(model.cls_out   != nullptr);
+                    GGML_ASSERT(model.cls_out_b != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+
+                    // broadcast across the embedding size to make it compatible with the llama_get_embeddings API
+                    cur = ggml_repeat(ctx0, cur, inp);
+                } break;
             default:
                 {
@@ -11457,18 +11476,6 @@ struct llm_build_context {
 
         cur = inpL;
 
-        // classification head
-        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-        // TODO: become pooling layer?
-        if (model.cls) {
-            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
-
-            cur = ggml_tanh(ctx0, cur);
-
-            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
-            // TODO: cur is now a scalar - what to do?
-        }
-
         cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -16446,7 +16453,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             }
         }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
         const int64_t n_tokens     = batch.n_tokens;
         const int64_t n_seq_tokens = batch.n_seq_tokens;
         const int64_t n_seqs       = batch.n_seqs;
@@ -16461,7 +16470,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             const llama_seq_id seq_id = batch.seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@@ -16988,6 +16997,7 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_MEAN:
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
+                case LLAMA_POOLING_TYPE_RANK:
                     {
                         // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
@@ -17191,6 +17201,7 @@ static int llama_encode_internal(
                 case LLAMA_POOLING_TYPE_MEAN:
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
+                case LLAMA_POOLING_TYPE_RANK:
                     {
                         // extract sequence embeddings
                         auto & embd_seq_out = lctx.embd_seq;
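With rank pooling the score produced by the classification head is broadcast across the embedding size (see the llm_build_context hunk above), so the first element of the pooled sequence embedding recovers it, exactly as the embedding example does with emb[j * n_embd]. A minimal sketch of the read-back path — not part of this commit; the helper name and sequence id are illustrative:

    #include "llama.h"
    #include <cstdio>

    // Hypothetical helper: print the rank score for a sequence that has already been decoded
    // in a context created with pooling_type == LLAMA_POOLING_TYPE_RANK.
    static void print_rank_score(struct llama_context * ctx, llama_seq_id seq_id) {
        const float * emb = llama_get_embeddings_seq(ctx, seq_id);
        if (emb == nullptr) {
            fprintf(stderr, "no pooled embedding for sequence %d\n", seq_id);
            return;
        }
        // the score is replicated across all n_embd elements, so element 0 is enough
        printf("rank score: %8.3f\n", emb[0]);
    }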