diff --git a/common/arg.cpp b/common/arg.cpp index 922391069..c210595cd 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -963,7 +963,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--tfs"}, "N", + {"--tfs", "--tfs-z"}, "Z", format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z), [](gpt_params & params, const std::string & value) { params.sparams.tfs_z = std::stof(value); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5299f5116..e2ee328d4 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -756,20 +756,22 @@ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_tok } } + assert(cur_p->size > 0); // guaranteed earlier + size_t last_idx = cur_p->size - 1; + float cum_sum = 0.0f; - size_t last_idx = cur_p->size; for (size_t i = 0; i < second_derivatives.size(); ++i) { cum_sum += second_derivatives[i]; // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > ctx->z && i >= ctx->min_keep) { + if (cum_sum > ctx->z && (i + 1) >= ctx->min_keep) { last_idx = i; break; } } // Resize the output vector to keep only the tokens above the tail location - cur_p->size = last_idx; + cur_p->size = last_idx + 1; } static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index d738b7a45..3a8ca9fe8 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -271,9 +271,9 @@ int main(void) { test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f); test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f); - test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); - test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f); - test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f); test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);