mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
87c2e8b279
commit
475df1d6cf
@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
} else if (arg == "--pooling") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
|
||||
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
|
||||
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
|
||||
else { invalid_param = true; break; }
|
||||
} else if (arg == "--defrag-thold" || arg == "-dt") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
|
||||
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" --pooling {none,mean,cls}\n");
|
||||
printf(" pooling type for embeddings, use model default if unspecified\n");
|
||||
printf(" -dt N, --defrag-thold N\n");
|
||||
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
|
||||
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
|
||||
@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
|
||||
|
@ -76,9 +76,12 @@ struct gpt_params {
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
||||
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
||||
|
||||
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
||||
|
||||
llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
||||
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
|
||||
|
||||
// // sampling parameters
|
||||
struct llama_sampling_params sparams;
|
||||
|
||||
|
@ -1644,16 +1644,17 @@ class BertModel(Model):
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
# get pooling path
|
||||
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
|
||||
modules = json.load(f)
|
||||
pooling_path = None
|
||||
module_path = self.dir_model / "modules.json"
|
||||
if module_path.is_file():
|
||||
with open(module_path, encoding="utf-8") as f:
|
||||
modules = json.load(f)
|
||||
for mod in modules:
|
||||
if mod["type"] == "sentence_transformers.models.Pooling":
|
||||
pooling_path = mod["path"]
|
||||
break
|
||||
|
||||
# get pooling type
|
||||
pooling_type = gguf.PoolingType.NONE
|
||||
if pooling_path is not None:
|
||||
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
|
||||
pooling = json.load(f)
|
||||
@ -1663,7 +1664,6 @@ class BertModel(Model):
|
||||
pooling_type = gguf.PoolingType.CLS
|
||||
else:
|
||||
raise NotImplementedError("Only MEAN and CLS pooling types supported")
|
||||
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def set_vocab(self):
|
||||
|
34
llama.cpp
34
llama.cpp
@ -873,16 +873,16 @@ struct LLM_TN {
|
||||
// gguf helpers
|
||||
//
|
||||
|
||||
static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
|
||||
static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
|
||||
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
|
||||
};
|
||||
|
||||
static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
return kv.first;
|
||||
return (llama_rope_scaling_type) kv.first;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1612,7 +1612,6 @@ struct llama_hparams {
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
int32_t rope_scaling_type_train;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
@ -1622,6 +1621,7 @@ struct llama_hparams {
|
||||
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
@ -1683,7 +1683,7 @@ struct llama_cparams {
|
||||
float defrag_thold;
|
||||
|
||||
bool offload_kqv;
|
||||
bool do_pooling;
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@ -2933,7 +2933,11 @@ template<>
|
||||
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
|
||||
uint32_t tmp;
|
||||
const bool found = get_key(kid, tmp, required);
|
||||
if (found) {
|
||||
result = (enum llama_pooling_type) tmp;
|
||||
} else {
|
||||
result = LLAMA_POOLING_TYPE_UNSPECIFIED;
|
||||
}
|
||||
return found;
|
||||
}
|
||||
|
||||
@ -3210,7 +3214,7 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
@ -5175,7 +5179,7 @@ struct llm_build_context {
|
||||
n_kv (worst_case ? n_ctx : kv_self.n),
|
||||
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
|
||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||
pooling_type (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
||||
@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||
@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
||||
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
||||
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
||||
/*.rope_freq_base =*/ 0.0f,
|
||||
/*.rope_freq_scale =*/ 0.0f,
|
||||
/*.yarn_ext_factor =*/ -1.0f,
|
||||
@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.logits_all =*/ false,
|
||||
/*.embedding =*/ false,
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.do_pooling =*/ true,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
};
|
||||
@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.do_pooling = params.do_pooling;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
} else {
|
||||
cparams.pooling_type = hparams.pooling_type;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
7
llama.h
7
llama.h
@ -129,6 +129,7 @@ extern "C" {
|
||||
};
|
||||
|
||||
enum llama_pooling_type {
|
||||
LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
|
||||
LLAMA_POOLING_TYPE_NONE = 0,
|
||||
LLAMA_POOLING_TYPE_MEAN = 1,
|
||||
LLAMA_POOLING_TYPE_CLS = 2,
|
||||
@ -236,7 +237,10 @@ extern "C" {
|
||||
uint32_t n_batch; // prompt processing maximum batch size
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
||||
// (ignored if no pooling layer)
|
||||
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
@ -258,7 +262,6 @@ extern "C" {
|
||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embedding; // embedding mode only
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
|
||||
|
||||
// Abort callback
|
||||
// if it returns true, execution of llama_decode() will be aborted
|
||||
|
Loading…
Reference in New Issue
Block a user