mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 09:41:45 +00:00
llama : handle temp <= 0.0 in the temp_ext sampler too
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
ggml-ci
This commit is contained in:
parent
cd978508ac
commit
4a5b5870f1
@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) {
|
||||
}
|
||||
*/
|
||||
|
||||
// Apply temperature scaling to the candidate logits in place.
//
//   temp >  0 : every logit is divided by temp
//   temp <= 0 : greedy selection - the first candidate with the highest logit keeps
//               its logit, every other candidate's logit is set to -INFINITY
//
// An empty candidate array is left untouched.
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
    if (cur_p->size == 0) {
        // nothing to scale - also keeps the greedy branch from reading data[0]
        return;
    }

    if (temp <= 0.0f) {
        // Track the winner by position instead of by token id: exactly one candidate
        // survives even if the array were to contain repeated ids, and the whole
        // operation needs only a single pass.
        size_t max_i = 0;
        float  max_l = cur_p->data[0].logit;

        for (size_t i = 1; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit > max_l) {
                // new best: mask the previous best and remember this one
                cur_p->data[max_i].logit = -INFINITY;
                max_i = i;
                max_l = cur_p->data[i].logit;
            } else {
                // not better than the current best (ties keep the earlier candidate)
                cur_p->data[i].logit = -INFINITY;
            }
        }

        return;
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit /= temp;
    }
}
|
||||
|
||||
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(cur_p->size > 0);
|
||||
|
||||
@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
||||
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
||||
|
||||
if (ctx->temp <= 0.0f) {
|
||||
// find the token with the highest logit and set the rest to -inf
|
||||
llama_token max_id = cur_p->data[0].id;
|
||||
float max_logit = cur_p->data[0].logit;
|
||||
|
||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit > max_logit) {
|
||||
max_id = cur_p->data[i].id;
|
||||
max_logit = cur_p->data[i].logit;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].id != max_id) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= ctx->temp;
|
||||
}
|
||||
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
||||
@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
||||
#endif
|
||||
|
||||
// Apply the dynamically calculated temperature scaling
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= dyn_temp;
|
||||
}
|
||||
llama_sampler_temp_impl(cur_p, dyn_temp);
|
||||
|
||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||
const double max_l_double = cur_p->data[0].logit;
|
||||
@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= ctx->temp;
|
||||
}
|
||||
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,6 +70,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
|
||||
tester.apply(llama_sampler_init_dist (0));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
@ -277,6 +288,9 @@ int main(void) {
|
||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||
|
||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
||||
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||
|
Loading…
Reference in New Issue
Block a user