llama : add "rank" pooling type

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-19 13:21:15 +03:00
parent d0a7bf9382
commit 125a0671ab
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 34 additions and 17 deletions

View File

@ -1098,8 +1098,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) { [](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); } else { throw std::invalid_argument("invalid value"); }
} }
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));

View File

@ -234,6 +234,10 @@ int main(int argc, char ** argv) {
} }
LOG("\n"); LOG("\n");
} }
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else { } else {
// print the first part of the embeddings or for a single prompt, the full embedding // print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) { for (int j = 0; j < n_prompts; j++) {

View File

@ -192,6 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3, LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4,
}; };
enum llama_attention_type { enum llama_attention_type {

View File

@ -10213,6 +10213,10 @@ struct llm_build_context {
struct ggml_tensor * cur; struct ggml_tensor * cur;
switch (pooling_type) { switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
{ {
struct ggml_tensor * inp_mean = build_inp_mean(); struct ggml_tensor * inp_mean = build_inp_mean();
@ -10224,9 +10228,24 @@ struct llm_build_context {
struct ggml_tensor * inp_cls = build_inp_cls(); struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls); cur = ggml_get_rows(ctx0, inp, inp_cls);
} break; } break;
case LLAMA_POOLING_TYPE_NONE: case LLAMA_POOLING_TYPE_RANK:
{ {
cur = inp; struct ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
GGML_ASSERT(model.cls != nullptr);
GGML_ASSERT(model.cls_b != nullptr);
GGML_ASSERT(model.cls_out != nullptr);
GGML_ASSERT(model.cls_out_b != nullptr);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
cur = ggml_repeat(ctx0, cur, inp);
} break; } break;
default: default:
{ {
@ -11457,18 +11476,6 @@ struct llm_build_context {
cur = inpL; cur = inpL;
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
// TODO: become pooling layer?
if (model.cls) {
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// TODO: cur is now a scalar - what to do?
}
cb(cur, "result_embd", -1); cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -16446,7 +16453,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
} }
} }
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens;
const int64_t n_seqs = batch.n_seqs; const int64_t n_seqs = batch.n_seqs;
@ -16461,7 +16470,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
const llama_seq_id seq_id = batch.seq_id[s][0]; const llama_seq_id seq_id = batch.seq_id[s][0];
// TODO: adapt limits to n_seqs when batch.equal_seqs is true // TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) { for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = batch.pos[s*n_seq_tokens + i]; const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@ -16988,6 +16997,7 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST: case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{ {
// extract sequence embeddings (cleared before processing each batch) // extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
@ -17191,6 +17201,7 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST: case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{ {
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;