llama : add "rank" pooling type

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-19 13:21:15 +03:00
parent d0a7bf9382
commit 125a0671ab
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 34 additions and 17 deletions

View File

@ -1100,6 +1100,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));

View File

@ -234,6 +234,10 @@ int main(int argc, char ** argv) {
}
LOG("\n");
}
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {

View File

@ -192,6 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4,
};
enum llama_attention_type {

View File

@ -10213,6 +10213,10 @@ struct llm_build_context {
struct ggml_tensor * cur;
switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
struct ggml_tensor * inp_mean = build_inp_mean();
@ -10224,9 +10228,24 @@ struct llm_build_context {
struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
} break;
case LLAMA_POOLING_TYPE_NONE:
case LLAMA_POOLING_TYPE_RANK:
{
cur = inp;
struct ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
GGML_ASSERT(model.cls != nullptr);
GGML_ASSERT(model.cls_b != nullptr);
GGML_ASSERT(model.cls_out != nullptr);
GGML_ASSERT(model.cls_out_b != nullptr);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
cur = ggml_repeat(ctx0, cur, inp);
} break;
default:
{
@ -11457,18 +11476,6 @@ struct llm_build_context {
cur = inpL;
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
// TODO: become pooling layer?
if (model.cls) {
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// TODO: cur is now a scalar - what to do?
}
cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur);
@ -16446,7 +16453,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = batch.n_tokens;
const int64_t n_seq_tokens = batch.n_seq_tokens;
const int64_t n_seqs = batch.n_seqs;
@ -16461,7 +16470,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
const llama_seq_id seq_id = batch.seq_id[s][0];
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@ -16988,6 +16997,7 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq;
@ -17191,6 +17201,7 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;