mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
sampling : fix off-by-one in tail-free sampling
ggml-ci
This commit is contained in:
parent
37f8c7b4c9
commit
114ab6347e
@ -963,7 +963,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(llama_arg(
|
||||
{"--tfs"}, "N",
|
||||
{"--tfs", "--tfs-z"}, "Z",
|
||||
format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
|
||||
[](gpt_params & params, const std::string & value) {
|
||||
params.sparams.tfs_z = std::stof(value);
|
||||
|
@ -756,20 +756,22 @@ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_tok
|
||||
}
|
||||
}
|
||||
|
||||
assert(cur_p->size > 0); // guaranteed earlier
|
||||
size_t last_idx = cur_p->size - 1;
|
||||
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = cur_p->size;
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
cum_sum += second_derivatives[i];
|
||||
|
||||
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
||||
if (cum_sum > ctx->z && i >= ctx->min_keep) {
|
||||
if (cum_sum > ctx->z && (i + 1) >= ctx->min_keep) {
|
||||
last_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
cur_p->size = last_idx;
|
||||
cur_p->size = last_idx + 1;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
|
||||
|
@ -271,9 +271,9 @@ int main(void) {
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
|
||||
|
||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||
|
Loading…
Reference in New Issue
Block a user