Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-31 22:04:35 +00:00
llama : add classification head (wip) [no ci]
commit d0a7bf9382
parent dc0cdd8760
@@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
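This hunk drops the .set_examples({LLAMA_EXAMPLE_MAIN}) call from the --verbose-prompt option, so the flag is no longer registered for the main example only. As a reading aid, below is a minimal standalone sketch of that example-scoping pattern; the type names, the registry, and the empty-set-means-every-example convention are hypothetical illustrations, not the actual llama.cpp common code.

// Hypothetical, minimal re-implementation of the example-scoping pattern
// shown in the hunk above, for illustration only (the real llama_arg /
// add_opt machinery in llama.cpp differs in detail).
#include <functional>
#include <set>
#include <string>
#include <utility>
#include <vector>

enum example_id { EXAMPLE_MAIN, EXAMPLE_SERVER, EXAMPLE_EMBEDDING };

struct params_t { bool verbose_prompt = false; };

struct arg_t {
    std::vector<std::string>        names;    // e.g. {"--verbose-prompt"}
    std::string                     help;
    std::function<void(params_t &)> handler;
    std::set<example_id>            examples; // empty = applies to all examples

    arg_t & set_examples(std::set<example_id> ex) {
        examples = std::move(ex);
        return *this;
    }
    bool applies_to(example_id ex) const {
        return examples.empty() || examples.count(ex) > 0;
    }
};

int main() {
    std::vector<arg_t> opts;

    // after this commit: no set_examples() call, so the option is
    // visible to every example, not just EXAMPLE_MAIN
    opts.push_back(arg_t{
        {"--verbose-prompt"},
        "print a verbose prompt before generation",
        [](params_t & p) { p.verbose_prompt = true; },
        {},
    });

    // the pre-commit pattern: restrict an option to one example
    opts.push_back(arg_t{
        {"--no-display-prompt"},
        "don't print prompt at generation",
        [](params_t &) {},
        {},
    }.set_examples({EXAMPLE_MAIN}));

    params_t params;
    for (const auto & opt : opts) {
        if (opt.applies_to(EXAMPLE_EMBEDDING)) {
            opt.handler(params); // pretend the flag was passed on the CLI
        }
    }
    return params.verbose_prompt ? 0 : 1; // 0: the unrestricted flag took effect
}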
@@ -11455,8 +11455,20 @@ struct llm_build_context {
             inpL = cur;
         }
 
         // final output
         cur = inpL;
 
+        // classification head
+        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+        // TODO: become pooling layer?
+        if (model.cls) {
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
+
+            cur = ggml_tanh(ctx0, cur);
+
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+            // TODO: cur is now a scalar - what to do?
+        }
+
         cb(cur, "result_embd", -1);
         ggml_build_forward_expand(gf, cur);
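The new block ports the RoBERTa sequence-classification head referenced in the comment: a dense layer (model.cls, model.cls_b), a tanh, then an output projection (model.cls_out, model.cls_out_b). The plain C++ sketch below spells out the math the three ggml ops build, logits = cls_out * tanh(cls * x + cls_b) + cls_out_b. The toy dimensions, the weight values, and the single-logit output shape (implied by the "cur is now a scalar" TODO) are assumptions for illustration only, not the real tensors.

// Plain C++ sketch of the classification-head math from the diff above.
// Shapes are toy-sized assumptions: cls is [n_embd x n_embd], cls_out is
// [1 x n_embd], so the head collapses the embedding to a single score.
#include <cmath>
#include <cstdio>
#include <vector>

// y = W*x + b, with W stored row-major as [n_out][n_in]
static std::vector<float> linear(const std::vector<float> & W,
                                 const std::vector<float> & b,
                                 const std::vector<float> & x) {
    const size_t n_out = b.size();
    const size_t n_in  = x.size();
    std::vector<float> y(n_out, 0.0f);
    for (size_t i = 0; i < n_out; ++i) {
        for (size_t j = 0; j < n_in; ++j) {
            y[i] += W[i*n_in + j] * x[j];
        }
        y[i] += b[i];
    }
    return y;
}

int main() {
    const size_t n_embd = 4; // toy size; real models use e.g. 768 or 1024

    // stand-in for `cur` after the "final output" step
    std::vector<float> x = {0.1f, -0.2f, 0.3f, 0.4f};

    // cls: dense [n_embd x n_embd] + bias (toy identity weights)
    std::vector<float> cls(n_embd*n_embd, 0.0f);
    for (size_t i = 0; i < n_embd; ++i) cls[i*n_embd + i] = 1.0f;
    std::vector<float> cls_b(n_embd, 0.0f);

    // cls_out: projection [1 x n_embd] to a single logit + bias
    std::vector<float> cls_out   = {0.5f, 0.5f, 0.5f, 0.5f};
    std::vector<float> cls_out_b = {0.0f};

    // dense -> tanh -> out_proj, mirroring the three ggml ops in the diff
    std::vector<float> h = linear(cls, cls_b, x);
    for (float & v : h) v = std::tanh(v);
    std::vector<float> logits = linear(cls_out, cls_out_b, h);

    printf("score = %f\n", logits[0]); // the "scalar" the TODO refers to
    return 0;
}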