mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 23:34:35 +00:00
llama : add classification head (wip) [no ci]
This commit is contained in:
parent
dc0cdd8760
commit
d0a7bf9382
@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.verbose_prompt = true;
|
params.verbose_prompt = true;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--no-display-prompt"},
|
{"--no-display-prompt"},
|
||||||
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
|
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
|
||||||
|
@ -11455,8 +11455,20 @@ struct llm_build_context {
|
|||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// final output
|
|
||||||
cur = inpL;
|
cur = inpL;
|
||||||
|
|
||||||
|
// classification head
|
||||||
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||||
|
// TODO: become pooling layer?
|
||||||
|
if (model.cls) {
|
||||||
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
|
||||||
|
|
||||||
|
cur = ggml_tanh(ctx0, cur);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
|
||||||
|
// TODO: cur is now a scalar - what to do?
|
||||||
|
}
|
||||||
|
|
||||||
cb(cur, "result_embd", -1);
|
cb(cur, "result_embd", -1);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
Loading…
Reference in New Issue
Block a user