This commit is contained in:
MaggotHATE 2024-10-31 04:26:28 +00:00 committed by GitHub
commit e0611f1a6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 140 additions and 16 deletions

View File

@ -922,6 +922,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.temp = std::max(params.sparams.temp, 0.0f);
}
).set_sparam());
add_opt(common_arg(
{"--k-shift"}, "N",
string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
[](common_params & params, int value) {
params.sparams.k_shift = value;
}
).set_sparam());
add_opt(common_arg(
{"--top-k"}, "N",
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),

View File

@ -2091,6 +2091,7 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "k_shift: %d # default: 0\n", sparams.k_shift);
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);

View File

@ -85,14 +85,15 @@ enum llama_example {
enum common_sampler_type {
COMMON_SAMPLER_TYPE_NONE = 0,
COMMON_SAMPLER_TYPE_DRY = 1,
COMMON_SAMPLER_TYPE_TOP_K = 2,
COMMON_SAMPLER_TYPE_TOP_P = 3,
COMMON_SAMPLER_TYPE_MIN_P = 4,
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_K_SHIFT = 2,
COMMON_SAMPLER_TYPE_TOP_K = 3,
COMMON_SAMPLER_TYPE_TOP_P = 4,
COMMON_SAMPLER_TYPE_MIN_P = 5,
//COMMON_SAMPLER_TYPE_TFS_Z = 6,
COMMON_SAMPLER_TYPE_TYPICAL_P = 7,
COMMON_SAMPLER_TYPE_TEMPERATURE = 8,
COMMON_SAMPLER_TYPE_XTC = 9,
COMMON_SAMPLER_TYPE_INFILL = 10,
};
// dimensionality reduction methods, used by cvector-generator
@ -108,6 +109,7 @@ struct common_sampler_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t k_shift = 0; // 0 = disabled
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
@ -137,6 +139,7 @@ struct common_sampler_params {
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_K_SHIFT,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,

View File

@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tk_shift = %d, top_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
k_shift, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
@ -187,6 +187,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_K_SHIFT:
llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift));
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
@ -369,6 +372,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return 'd';
case COMMON_SAMPLER_TYPE_K_SHIFT: return 's';
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
@ -383,6 +387,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return "dry";
case COMMON_SAMPLER_TYPE_K_SHIFT: return "k_shift";
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
@ -398,6 +403,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "k_shift", COMMON_SAMPLER_TYPE_K_SHIFT },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
@ -410,6 +416,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
// make it ready for both system names and input names
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "k-shift", COMMON_SAMPLER_TYPE_K_SHIFT },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
@ -443,6 +450,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
std::unordered_map<char, common_sampler_type> sampler_name_map = {
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_K_SHIFT), COMMON_SAMPLER_TYPE_K_SHIFT },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },

View File

@ -211,6 +211,14 @@ DRY sampling provides more nuanced control over text generation, particularly fo
Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`
### K-Shift Sampling
- `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0).
K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out k top tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentoned in a paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick to guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostics tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too - albeit, not as much as XTC. The default value is 0.
Example usage: `--k-shift 10`
### Top-K Sampling
- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).

View File

@ -44,7 +44,8 @@
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
top_k: 0, // <= 0 to use vocab size
k_shift: 0, // <= 0 to use vocab size
top_k: 0, // 0 = disabled
top_p: 1.0, // 1.0 = disabled
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
xtc_probability: 0.0, // 0 = disabled;
@ -834,6 +835,7 @@ return html`
<details>
<summary><span class="summary-title">Further Options</span></summary>
<fieldset class="params">
${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })}
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}

View File

@ -308,6 +308,7 @@
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
k_shift: 0, // 0 = disabled
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled
@ -1007,6 +1008,7 @@
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

View File

@ -804,6 +804,7 @@ struct server_context {
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@ -1143,6 +1144,7 @@ struct server_context {
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
{"k_shift", slot.sparams.k_shift},
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},

View File

@ -1096,6 +1096,9 @@ extern "C" {
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

View File

@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
cur_p->size = k;
}
static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
// sort before shifting
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
// shift to a token #[k]
cur_p->data += k;
cur_p->size -= k;
}
static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
@ -1082,6 +1093,64 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
};
}
// k-shift
struct llama_sampler_k_shift {
const int32_t k;
bool k_set;
};
static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) {
return "k-shift";
}
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
if (ctx->k_set == true
|| ctx->k <= 0
|| ctx->k >= (int) cur_p->size) {
return;
}
llama_sampler_top_shift_impl(cur_p, ctx->k);
ctx->k_set = true;
}
static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;
return llama_sampler_init_k_shift(ctx->k);
}
static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
delete (llama_sampler_k_shift *) smpl->ctx;
}
static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
ctx->k_set = false;
}
static struct llama_sampler_i llama_sampler_k_shift_i = {
/* .name = */ llama_sampler_k_shift_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_k_shift_apply,
/* .reset = */ llama_sampler_k_shift_reset,
/* .clone = */ llama_sampler_k_shift_clone,
/* .free = */ llama_sampler_k_shift_free,
};
struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
return new llama_sampler {
/* .iface = */ &llama_sampler_k_shift_i,
/* .ctx = */ new llama_sampler_k_shift {
/* .k = */ k,
/* .k_set = */ false,
},
};
}
// mirostat
struct llama_sampler_mirostat {

View File

@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector<float> & probs, const std::vector<fl
tester.check();
}
static void test_k_shift(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_k_shift(k));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);
tester.check();
}
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);
@ -288,11 +299,13 @@ static void test_perf() {
data.emplace_back(llama_token_data{i, logit, 0.0f});
}
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
BENCH(llama_sampler_init_k_shift (10), data, 32);
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
}
int main(void) {
@ -304,6 +317,12 @@ int main(void) {
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);