llama : add classification head (wip) [no ci]

Georgi Gerganov 2024-09-18 21:20:21 +03:00
parent 00f40ae0ef
commit 152e90331e
2 changed files with 14 additions and 2 deletions


@@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),


@@ -11291,8 +11291,20 @@ struct llm_build_context {
             inpL = cur;
         }
 
         // final output
         cur = inpL;
 
+        // classification head
+        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+        // TODO: become pooling layer?
+        if (model.cls) {
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+
+            // TODO: cur is now a scalar - what to do?
+        }
+
         cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
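
For reference, the new head is the RoBERTa-style chain dense -> tanh -> output projection from the linked transformers file. The standalone sketch below reproduces the same chain against the ggml C API as it looked around this commit; the tensor sizes and fill values are made up, and embd/cls/cls_b/cls_out/cls_out_b are dummy stand-ins for the loaded model weights. Its output has shape [1, n_tokens], one value per token, which is exactly the "cur is now a scalar" situation flagged in the TODO above:

    #include "ggml.h"

    #include <cstdio>

    int main() {
        const int n_embd   = 8; // hypothetical embedding width
        const int n_tokens = 4; // hypothetical sequence length

        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16u*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(ip);

        // dummy stand-ins for model.cls, model.cls_b, model.cls_out, model.cls_out_b
        struct ggml_tensor * embd      = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
        struct ggml_tensor * cls       = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_embd);
        struct ggml_tensor * cls_b     = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd);
        struct ggml_tensor * cls_out   = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, 1);
        struct ggml_tensor * cls_out_b = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);

        ggml_set_f32(embd,      0.5f);
        ggml_set_f32(cls,       0.1f);
        ggml_set_f32(cls_b,     0.0f);
        ggml_set_f32(cls_out,   0.2f);
        ggml_set_f32(cls_out_b, 0.0f);

        // same chain as the hunk: dense -> tanh -> output projection
        struct ggml_tensor * cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, embd), cls_b);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
        ggml_build_forward_expand(gf, cur);
        ggml_graph_compute_with_ctx(ctx0, gf, /*n_threads =*/ 1);

        // [1, n_tokens]: one logit per token - the "scalar" from the TODO;
        // a pooling step (e.g. taking the CLS token) would pick one of these
        printf("head output shape: [%d, %d]\n", (int) cur->ne[0], (int) cur->ne[1]);
        printf("first value: %f\n", (double) ggml_get_f32_1d(cur, 0));

        ggml_free(ctx0);
        return 0;
    }

Compile by linking against a ggml build from this tree; exact flags depend on your setup.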