Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-25 02:44:36 +00:00)
sampling : custom samplers order (#4285)
* Samplers sequence order w parameter
* Cleaned commented code
* Fixed formatting
* Rewrote with unordered_map
* Revert and rewrite, too many problems and safeguards would be needed
* Fixed code style
* Code style fixes according to review
* More readable samplers input string, fixed help
* Style fix in sampler_queue
* Formatting fixes
* Fixing whitespaces
This commit: 52c8bc3cf3 (parent: e4b76bbe31)
@@ -280,6 +280,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.yarn_beta_slow = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
+        } else if (arg == "--samplers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = parse_samplers_input(argv[i]);
+        } else if (arg == "--sampling-seq") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = argv[i];
         } else if (arg == "--top-p") {
             if (++i >= argc) {
                 invalid_param = true;
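For illustration, the two new flags are interchangeable: --samplers takes full sampler names that parse_samplers_input() collapses into the single-character sequence that --sampling-seq accepts directly. A minimal sketch (not part of the commit; assumes the common.h from this change is on the include path):

#include <cassert>
#include <string>
#include "common.h"

int main() {
    // equivalent to passing --samplers "top_k;tfs_z;typical_p;top_p;min_p;temp"
    const std::string seq = parse_samplers_input("top_k;tfs_z;typical_p;top_p;min_p;temp");
    assert(seq == "kfypmt"); // the same string --sampling-seq kfypmt would set
    return 0;
}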
@@ -761,6 +773,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+    printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -886,6 +900,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
     GGML_UNREACHABLE();
 }
 
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+    std::string output = "";
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, char> samplers_symbols {
+        {"top_k", 'k'},
+        {"top-k", 'k'},
+        {"top_p", 'p'},
+        {"top-p", 'p'},
+        {"nucleus", 'p'},
+        {"typical_p", 'y'},
+        {"typical-p", 'y'},
+        {"typical", 'y'},
+        {"min_p", 'm'},
+        {"min-p", 'm'},
+        {"tfs_z", 'f'},
+        {"tfs-z", 'f'},
+        {"tfs", 'f'},
+        {"temp", 't'},
+        {"temperature",'t'}
+    };
+    // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+    size_t separator = input.find(';');
+    while (separator != input.npos) {
+        std::string name = input.substr(0,separator);
+        input = input.substr(separator+1);
+        separator = input.find(';');
+
+        if (samplers_symbols.find(name) != samplers_symbols.end()) {
+            output += samplers_symbols[name];
+        }
+    }
+    if (samplers_symbols.find(input) != samplers_symbols.end()) {
+        output += samplers_symbols[input];
+    }
+    return output;
+}
+
 //
 // Model utils
 //
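Two behaviors of this parser are worth spelling out: alias spellings collapse to one symbol, and unrecognized names are skipped silently rather than rejected. A small self-check, again assuming common.h is available (the sampler-name inputs are illustrative):

#include <cassert>
#include "common.h"

int main() {
    // aliases: temperature -> 't', nucleus -> 'p', min-p -> 'm'
    assert(parse_samplers_input("temperature;nucleus;min-p") == "tpm");
    // unknown names are dropped; the fragment after the last ';' is parsed too
    assert(parse_samplers_input("top_k;not_a_sampler;temp") == "kt");
    return 0;
}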
@@ -141,6 +141,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 
 void process_escapes(std::string& input);
 
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input);
+
 //
 // Model utils
 //
@@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
+std::string llama_sampling_order_print(const llama_sampling_params & params) {
+    std::string result = "CFG -> Penalties ";
+    if (params.mirostat == 0) {
+        for (auto s : params.samplers_sequence) {
+            switch (s) {
+                case 'k': result += "-> top_k "; break;
+                case 'f': result += "-> tfs_z "; break;
+                case 'y': result += "-> typical_p "; break;
+                case 'p': result += "-> top_p "; break;
+                case 'm': result += "-> min_p "; break;
+                case 't': result += "-> temp "; break;
+                default : break;
+            }
+        }
+    } else result += "-> mirostat ";
+
+    return result;
+}
+
+// no reasons to expose this function in header
+void sampler_queue(
+                   struct llama_context * ctx_main,
+            const llama_sampling_params & params,
+                 llama_token_data_array & cur_p,
+                                 size_t & min_keep) {
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const float         temp              = params.temp;
+    const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float         top_p             = params.top_p;
+    const float         min_p             = params.min_p;
+    const float         tfs_z             = params.tfs_z;
+    const float         typical_p         = params.typical_p;
+    const std::string & samplers_sequence = params.samplers_sequence;
+
+    for (auto s : samplers_sequence) {
+        switch (s){
+            case 'k': llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
+            case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
+            case 'y': llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
+            case 'p': llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
+            case 'm': llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
+            case 't': llama_sample_temp     (ctx_main, &cur_p, temp); break;
+            default : break;
+        }
+    }
+}
+
 llama_token llama_sampling_sample(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
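With default parameters (mirostat off, sequence "kfypmt"), the order printout follows directly from the switch above. A minimal sketch, assuming sampling.h from this change is on the include path:

#include <cstdio>
#include "sampling.h"

int main() {
    llama_sampling_params params; // defaults: mirostat = 0, samplers_sequence = "kfypmt"
    // prints roughly: CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temp
    printf("%s\n", llama_sampling_order_print(params).c_str());
    return 0;
}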
@@ -109,11 +157,6 @@ llama_token llama_sampling_sample(
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
     const float   temp            = params.temp;
-    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
-    const float   top_p           = params.top_p;
-    const float   min_p           = params.min_p;
-    const float   tfs_z           = params.tfs_z;
-    const float   typical_p       = params.typical_p;
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
@@ -188,12 +231,7 @@ llama_token llama_sampling_sample(
             // temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
 
-            llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep);
-            llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
-            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
-            llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
-            llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep);
-            llama_sample_temp     (ctx_main, &cur_p, temp);
+            sampler_queue(ctx_main, params, cur_p, min_keep);
 
             id = llama_sample_token(ctx_main, &cur_p);
 
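The practical upshot of this refactor is that a caller reorders the whole pipeline by changing one string before sampling. A hedged sketch (sample_temp_first is a hypothetical helper, not from the commit; assumes a loaded llama_context as in examples/main):

#include "sampling.h"

// Hypothetical helper: apply temperature before top_k instead of the
// default "kfypmt" order, then sample one token.
llama_token sample_temp_first(struct llama_context * ctx, llama_sampling_params sparams) {
    sparams.samplers_sequence = "tk"; // temp first, then top_k
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /*ctx_cfg=*/nullptr);
    llama_sampling_free(ctx_sampling);
    return id;
}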
@@ -26,6 +26,7 @@ typedef struct llama_sampling_params {
     float       mirostat_tau = 5.00f; // target entropy
     float       mirostat_eta = 0.10f; // learning rate
     bool        penalize_nl  = true;  // consider newlines as a repeatable token
+    std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
 
     std::string grammar;  // optional BNF-like grammar to constrain sampling
 
@@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
 // Print sampling parameters into a string
 std::string llama_sampling_print(const llama_sampling_params & params);
 
+// Print sampling order into a string
+std::string llama_sampling_order_print(const llama_sampling_params & params);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
@@ -437,6 +437,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
    LOG_TEE("\n\n");
 