From 8241bc71b55c2489764c18aa6eb264418ae19bbb Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 23 Sep 2024 12:19:32 +0300
Subject: [PATCH] sampling : avoid expensive softmax during greedy sampling

ggml-ci
---
 common/sampling.cpp     |  5 ++++-
 include/llama.h         |  1 +
 src/llama-sampling.cpp  |  7 ++++---
 tests/test-sampling.cpp | 42 ++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index e51d07611..345abd221 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -209,7 +209,10 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        if (params.n_probs > 0) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
diff --git a/include/llama.h b/include/llama.h
index f316a87ba..132937a07 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1066,6 +1066,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 5299f5116..e255a8fc4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -3,13 +3,14 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
-#include
 #include
-#include
-#include
+#include
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index d738b7a45..2c79ec472 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,6 +1,5 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
             samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+#define BENCH(__cnstr, __data, __n_iter) do {                                     \
+    auto * cnstr = (__cnstr);                                                     \
+    std::vector<llama_token_data> cur((__data).size());                           \
+    std::copy((__data).begin(), (__data).end(), cur.begin());                     \
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };         \
+    llama_sampler_apply(cnstr, &cur_p);                                           \
+    llama_sampler_reset(cnstr);                                                   \
+    const int64_t t_start = ggml_time_us();                                       \
+    const int n_iter = (__n_iter);                                                \
+    for (int i = 0; i < n_iter; i++) {                                            \
+        std::copy((__data).begin(), (__data).end(), cur.begin());                 \
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };     \
+        llama_sampler_apply(cnstr, &cur_p);                                       \
+        llama_sampler_reset(cnstr);                                               \
+    }                                                                             \
+    const int64_t t_end = ggml_time_us();                                         \
+    llama_sampler_free(cnstr);                                                    \
+    printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
+} while(0)
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k    (40),      data, 32);
+    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_softmax  (),        data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
@@ -316,5 +354,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
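
Note (not part of the patch): the common/sampling.cpp change only pays for sorting and softmax when the caller actually asked for top-token probabilities; plain greedy sampling just picks the max-logit token. Below is a minimal sketch of the resulting chain construction, assuming the llama.h sampler-chain API at this commit (llama_sampler_chain_init, llama_sampler_chain_add, llama_sampler_init_top_k/softmax/greedy); the helper name make_greedy_chain and the n_probs parameter (standing in for gpt_sampler_params::n_probs) are illustrative only.

    // sketch: build a greedy chain, adding top-k + softmax only when the
    // caller requested the probabilities of the top n_probs tokens
    #include "llama.h"

    static struct llama_sampler * make_greedy_chain(int32_t n_probs) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        if (n_probs > 0) {
            // restrict to the n_probs best tokens first, then normalize only those,
            // instead of sorting + softmaxing the full vocabulary
            llama_sampler_chain_add(chain, llama_sampler_init_top_k(n_probs));
            llama_sampler_chain_add(chain, llama_sampler_init_softmax());
        }

        // greedy selects the max-logit token and needs no probabilities
        llama_sampler_chain_add(chain, llama_sampler_init_greedy());

        return chain; // caller releases it with llama_sampler_free()
    }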