llama : streamline embeddings from "non-embedding" models (#8087)
parent: bcefa03bc0
commit: d12f781074
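In summary, grounded in the diff below: the commit adds a `llama_attention_type` enum and a per-context `attention_type` parameter (plus a matching `--attention {causal,non-causal}` flag in common), gates the pooling-input setup and output counting on `cparams.embeddings`, and lets a context override the model's causal-attention default, so that pooled embeddings can be extracted even from models that were not trained as embedding models.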
@@ -472,6 +472,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--attention") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         CHECK_ARG
         params.defrag_thold = std::stof(argv[i]);
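The new flag follows the same cascaded string-compare pattern as the surrounding options. As a minimal sketch (not code from the commit, which parses inline in gpt_params_find_arg; the helper name is hypothetical), the mapping could be factored out like this:

    // hypothetical helper, illustrative only: maps the --attention argument
    // string to the new enum; the commit itself does this inline
    #include <string>
    #include "llama.h"

    static bool parse_attention_type(const std::string & value, enum llama_attention_type & out) {
        /**/ if (value == "causal")     { out = LLAMA_ATTENTION_TYPE_CAUSAL;     return true; }
        else if (value == "non-causal") { out = LLAMA_ATTENTION_TYPE_NON_CAUSAL; return true; }
        return false; // caller flags invalid_param
    }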
@@ -1468,8 +1476,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                 "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });

     options.push_back({ "embedding" });
-    options.push_back({ "embedding", "       --pooling {none,mean,cls}",
+    options.push_back({ "embedding", "       --pooling {none,mean,cls,last}",
                 "pooling type for embeddings, use model default if unspecified" });
+    options.push_back({ "embedding", "       --attention {causal,non-causal}",
+                "attention type for embeddings, use model default if unspecified" });

     options.push_back({ "context hacking" });
     options.push_back({ "*",           "       --rope-scaling {none,linear,yarn}",
@@ -2175,6 +2185,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
@@ -99,6 +99,7 @@ struct gpt_params {
     enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

     // // sampling parameters
     struct llama_sampling_params sparams;
@@ -180,6 +180,12 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };

+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
+        LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
+    };
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -297,6 +303,7 @@ extern "C" {

         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+        enum llama_attention_type    attention_type;    // attention type to use for embeddings

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;  // RoPE base frequency, 0 = from model
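Together, the new enum and the new llama_context_params field form the public API surface of the change. A minimal usage sketch of how a caller might request non-causal attention for an embeddings context (the model path, pooling choice, and lack of error handling are illustrative, not from the commit):

    // minimal sketch, assuming a GGUF model on disk
    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings     = true;                            // extract embeddings
        cparams.pooling_type   = LLAMA_POOLING_TYPE_MEAN;         // pool per sequence
        cparams.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; // override the model default

        llama_context * ctx = llama_new_context_with_model(model, cparams);

        // ... tokenize, llama_decode(), then llama_get_embeddings_seq(ctx, 0) ...

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }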
@@ -13840,7 +13840,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;

         GGML_ASSERT(lctx.inp_mean);
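This hunk and the two that follow gate the pooling-input setup on `cparams.embeddings`, so a generation-only context no longer builds `inp_mean`/`inp_cls`. For orientation, mean pooling averages a sequence's token embeddings; a plain-loop sketch of what the `inp_mean` path computes (inside llama.cpp this is realized via the `inp_mean` tensor rather than an explicit loop):

    // illustrative only: mean pooling over one sequence's token embeddings
    #include <vector>

    static std::vector<float> mean_pool(const std::vector<std::vector<float>> & tok_embd) {
        const size_t n_tokens = tok_embd.size();
        const size_t n_embd   = tok_embd.empty() ? 0 : tok_embd[0].size();
        std::vector<float> out(n_embd, 0.0f);
        for (const auto & e : tok_embd) {
            for (size_t j = 0; j < n_embd; ++j) {
                out[j] += e[j]; // accumulate per dimension
            }
        }
        for (size_t j = 0; j < n_embd; ++j) {
            out[j] /= (float) n_tokens; // average over the sequence
        }
        return out;
    }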
@@ -13872,7 +13872,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;

         GGML_ASSERT(lctx.inp_cls);
@@ -13893,7 +13893,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
         const int64_t n_tokens = batch.n_tokens;

         GGML_ASSERT(lctx.inp_cls);
@@ -14181,14 +14181,15 @@ static int llama_decode_internal(
     std::vector<llama_seq_id *> seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;

+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
-        n_outputs = n_tokens_all;
-    } else if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
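The `embd_pooled` flag reorders the output-count logic: with pooled embeddings every token becomes an output and `batch.logits` is ignored. The rule in isolation, as an illustrative standalone function (not code from the commit; names mirror the diff):

    // illustrative only: the output-count rule after this change
    #include <cstdint>

    static int32_t count_outputs(const int8_t * logits, uint32_t n_tokens,
                                 bool logits_all, bool embd_pooled) {
        if (logits && !embd_pooled) {
            int32_t n = 0;
            for (uint32_t i = 0; i < n_tokens; ++i) {
                n += logits[i] != 0; // only tokens explicitly flagged for output
            }
            return n;
        }
        if (logits_all || embd_pooled) {
            return (int32_t) n_tokens; // pooled embeddings output every token
        }
        return 1; // keep last output only
    }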
@@ -14234,7 +14235,7 @@ static int llama_decode_internal(
        {
            int32_t n_outputs_new = 0;

-           if (u_batch.logits) {
+           if (u_batch.logits && !embd_pooled) {
                for (uint32_t i = 0; i < n_tokens; i++) {
                    n_outputs_new += u_batch.logits[i] != 0;
                }
@@ -18533,6 +18534,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch  =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type      =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type    =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base    =*/ 0.0f,
         /*.rope_freq_scale   =*/ 0.0f,
         /*.yarn_ext_factor   =*/ -1.0f,
@@ -18785,7 +18787,6 @@ struct llama_context * llama_new_context_with_model(
     }

     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -18795,6 +18796,12 @@ struct llama_context * llama_new_context_with_model(
         }
     }

+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
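The unconditional `cparams.causal_attn = hparams.causal_attn` assignment removed in the previous hunk is replaced by this explicit resolution: an unspecified `attention_type` defers to the model's hyperparameters, anything else overrides them. The same rule as a pure function (illustrative only; the helper name is hypothetical):

    // illustrative only: how the per-context attention override resolves
    #include "llama.h"

    static bool resolve_causal_attn(enum llama_attention_type requested, bool model_causal) {
        if (requested == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
            return model_causal; // defer to the model's hyperparameters
        }
        return requested == LLAMA_ATTENTION_TYPE_CAUSAL;
    }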