mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-09 10:11:44 +00:00
llama : handle temp <= 0.0 in the temp_ext sampler too
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
ggml-ci
This commit is contained in:
parent
cd978508ac
commit
4a5b5870f1
@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) {
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Apply temperature to the logits of the candidate array, in place.
//
//   cur_p - candidate array; only the `logit` field of its entries is modified
//   temp  - temperature
//
// temp >  0: every logit is divided by temp.
// temp <= 0: greedy selection - the entry with the highest logit keeps its value and
//            every other logit is set to -INFINITY (on ties the first entry wins).
// An empty candidate array is left untouched.
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
    // nothing to scale or select; also prevents reading data[0] of an empty array below
    if (cur_p->size == 0) {
        return;
    }

    if (temp <= 0.0f) {
        // find the position of the highest logit and set the rest to -inf
        // the winner is tracked by position instead of by token id, so exactly one
        // entry survives even if the same id appears more than once
        size_t max_i = 0;

        for (size_t i = 1; i < cur_p->size; ++i) {
            // strict '>' keeps the first of several equal maxima
            if (cur_p->data[i].logit > cur_p->data[max_i].logit) {
                max_i = i;
            }
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            if (i != max_i) {
                cur_p->data[i].logit = -INFINITY;
            }
        }

        return;
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit /= temp;
    }
}
|
||||||
|
|
||||||
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
||||||
GGML_ASSERT(cur_p->size > 0);
|
GGML_ASSERT(cur_p->size > 0);
|
||||||
|
|
||||||
@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
|||||||
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx->temp <= 0.0f) {
|
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||||
// find the token with the highest logit and set the rest to -inf
|
|
||||||
llama_token max_id = cur_p->data[0].id;
|
|
||||||
float max_logit = cur_p->data[0].logit;
|
|
||||||
|
|
||||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
|
||||||
if (cur_p->data[i].logit > max_logit) {
|
|
||||||
max_id = cur_p->data[i].id;
|
|
||||||
max_logit = cur_p->data[i].logit;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
if (cur_p->data[i].id != max_id) {
|
|
||||||
cur_p->data[i].logit = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
cur_p->data[i].logit /= ctx->temp;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
||||||
@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Apply the dynamically calculated temperature scaling
|
// Apply the dynamically calculated temperature scaling
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
llama_sampler_temp_impl(cur_p, dyn_temp);
|
||||||
cur_p->data[i].logit /= dyn_temp;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||||
const double max_l_double = cur_p->data[0].logit;
|
const double max_l_double = cur_p->data[0].logit;
|
||||||
@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||||
cur_p->data[i].logit /= ctx->temp;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,6 +70,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
|
|||||||
tester.check();
|
tester.check();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exercise the extended temperature sampler: run the sampler created by
// llama_sampler_init_temp_ext(temp, delta, exponent) and then the one created by
// llama_sampler_init_dist(0) over a tester built from `probs` / `probs_expected`,
// and let the tester verify the outcome.
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
    sampler_tester tester(probs, probs_expected);

    // candidates before any sampler has been applied
    DUMP(&tester.cur_p);

    auto * smpl_temp_ext = llama_sampler_init_temp_ext(temp, delta, exponent);
    tester.apply(smpl_temp_ext);

    auto * smpl_dist = llama_sampler_init_dist(0);
    tester.apply(smpl_dist);

    // candidates after both samplers have been applied
    DUMP(&tester.cur_p);

    tester.check();
}
|
||||||
|
|
||||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
||||||
sampler_tester tester(probs, probs_expected);
|
sampler_tester tester(probs, probs_expected);
|
||||||
|
|
||||||
@ -277,6 +288,9 @@ int main(void) {
|
|||||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||||
|
|
||||||
|
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
||||||
|
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
||||||
|
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||||
|
Loading…
Reference in New Issue
Block a user