mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 10:41:47 +00:00
sampling : change temperature sampler logic
For t <= 0.0f, keep the max logit intact and set the rest to -inf
This commit is contained in:
parent
33a69ec742
commit
cb75bebcad
@ -171,7 +171,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
params.penalize_nl,
|
params.penalize_nl,
|
||||||
params.ignore_eos));
|
params.ignore_eos));
|
||||||
|
|
||||||
if (params.temp > 0.0f) {
|
if (params.temp >= 0.0f) {
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
for (const auto & cnstr : params.samplers) {
|
for (const auto & cnstr : params.samplers) {
|
||||||
switch (cnstr) {
|
switch (cnstr) {
|
||||||
@ -214,6 +214,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// negative temperatures will trigger "greedy" sampling: simply take the most likely token each time
|
||||||
if (params.n_probs > 0) {
|
if (params.n_probs > 0) {
|
||||||
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
|
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
|
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
|
||||||
|
@ -1104,6 +1104,8 @@ extern "C" {
|
|||||||
|
|
||||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
||||||
|
|
||||||
|
/// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
||||||
|
|
||||||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||||
|
@ -915,6 +915,28 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
|||||||
|
|
||||||
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
|
||||||
|
if (ctx->temp <= 0.0f) {
|
||||||
|
// find the token with the highest logit and set the rest to -inf
|
||||||
|
llama_token max_id = cur_p->data[0].id;
|
||||||
|
float max_logit = cur_p->data[0].logit;
|
||||||
|
|
||||||
|
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||||
|
if (cur_p->data[i].logit > max_logit) {
|
||||||
|
max_id = cur_p->data[i].id;
|
||||||
|
max_logit = cur_p->data[i].logit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
if (cur_p->data[i].id != max_id) {
|
||||||
|
cur_p->data[i].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
cur_p->data[i].logit /= ctx->temp;
|
cur_p->data[i].logit /= ctx->temp;
|
||||||
}
|
}
|
||||||
@ -964,6 +986,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||||||
if (ctx->delta > 0) {
|
if (ctx->delta > 0) {
|
||||||
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
||||||
const float max_temp = ctx->temp + ctx->delta;
|
const float max_temp = ctx->temp + ctx->delta;
|
||||||
|
|
||||||
float exponent_val = ctx->exponent;
|
float exponent_val = ctx->exponent;
|
||||||
|
|
||||||
// no need to do anything if there is only one (or zero) candidates
|
// no need to do anything if there is only one (or zero) candidates
|
||||||
|
@ -274,6 +274,9 @@ static void test_perf() {
|
|||||||
int main(void) {
|
int main(void) {
|
||||||
ggml_time_init();
|
ggml_time_init();
|
||||||
|
|
||||||
|
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||||
|
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||||
|
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||||
|
Loading…
Reference in New Issue
Block a user