llama : handle temp <= 0.0 in the temp_ext sampler too
Some checks are pending
flake8 Lint / Lint (push) Waiting to run

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-17 22:53:22 +03:00
parent cd978508ac
commit 4a5b5870f1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 44 additions and 30 deletions

View File

@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) {
}
*/
// Apply the temperature `temp` to the logits stored in `cur_p`, in place.
//
//   temp >  0 : every logit is divided by `temp`
//   temp <= 0 : greedy selection - the entry with the highest logit keeps its value and
//               every other logit is set to -INFINITY (the first entry wins on ties)
//
// An empty array is left untouched: the greedy path must not read data[0] when size == 0.
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
    if (temp <= 0.0f) {
        if (cur_p->size == 0) {
            return;
        }

        // single pass: track the index of the best entry seen so far and mask everything else
        size_t max_i = 0;
        float  max_l = cur_p->data[0].logit;

        for (size_t i = 1; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit > max_l) {
                // the previous best has been beaten - mask it and remember the new one
                cur_p->data[max_i].logit = -INFINITY;
                max_i = i;
                max_l = cur_p->data[i].logit;
            } else {
                cur_p->data[i].logit = -INFINITY;
            }
        }

        return;
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit /= temp;
    }
}
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0);
@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
// Apply callback of the plain temperature sampler: reads the temperature stored in the
// sampler context and applies it to the candidates in `cur_p`.
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_temp *) smpl->ctx;

    // Both the temp <= 0 (greedy) case and the regular scaling are handled by
    // llama_sampler_temp_impl. Repeating either of them here as well would scale
    // the logits by the temperature twice.
    llama_sampler_temp_impl(cur_p, ctx->temp);
}
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
#endif
// Apply the dynamically calculated temperature scaling
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= dyn_temp;
}
llama_sampler_temp_impl(cur_p, dyn_temp);
// Re-compute softmax probabilities after scaling logits with dynamic temperature
const double max_l_double = cur_p->data[0].logit;
@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
}
#endif
} else {
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= ctx->temp;
}
llama_sampler_temp_impl(cur_p, ctx->temp);
}
}

View File

@ -70,6 +70,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
tester.check();
}
// Exercise the extended temperature sampler.
//
// Builds a sampler_tester from `probs` / `probs_expected`, applies the sampler returned by
// llama_sampler_init_temp_ext(temp, delta, exponent) followed by the one returned by
// llama_sampler_init_dist(0), dumps the candidates before and after, and finally runs
// tester.check() (presumably comparing against `probs_expected` - see sampler_tester).
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
    sampler_tester tester(probs, probs_expected);

    DUMP(&tester.cur_p);

    auto * smpl_temp_ext = llama_sampler_init_temp_ext(temp, delta, exponent);
    tester.apply(smpl_temp_ext);

    auto * smpl_dist = llama_sampler_init_dist(0);
    tester.apply(smpl_dist);

    DUMP(&tester.cur_p);

    tester.check();
}
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);
@ -277,6 +288,9 @@ int main(void) {
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);