diff --git a/common/sampling.cpp b/common/sampling.cpp
index 56cd0df6b..f536c1e0a 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -203,7 +203,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
@@ -222,7 +221,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
         // it is much faster, since we avoid sorting all tokens and should give a good approximation
         llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
     }
     llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
 }
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index dcd9803a2..65cd4eb51 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -46,7 +46,6 @@ actor LlamaContext {
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
         llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
-        llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
         llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
     }
 
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 3866cfa27..89d60ec2e 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -42,7 +42,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
 
     // tokenize prompt
@@ -96,7 +95,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsecond run: %s", params.prompt.c_str());
@@ -156,7 +154,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 5a7b3084f..df84af4a1 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -180,8 +180,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -624,7 +622,6 @@ int main(int argc, char ** argv) {
         common_sampler_free(drafts[s].smpl);
     }
 
-    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
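Note: all four call sites converge on the same pattern — drop the explicit softmax stage and let the dist sampler normalize the surviving candidates itself (see the src/llama-sampling.cpp hunk below). A minimal sketch of a chain built against the new API, not taken from the patch; the top-k and temperature values are illustrative placeholders:

#include "llama.h"

// build a sampler chain without an explicit softmax stage;
// llama_sampler_init_dist() now normalizes the candidates before drawing
static llama_sampler * make_chain_sketch(uint32_t seed) {
    llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));   // truncate (placeholder k)
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));  // scale logits (placeholder temp)
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));  // softmax + draw happen here

    return chain;
}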
diff --git a/include/llama.h b/include/llama.h
index 02bc7f087..2206ef27d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -217,6 +217,7 @@ extern "C" {
 
     typedef struct llama_token_data_array {
         // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
        llama_token_data * data;
        size_t size;
        int64_t selected; // this is the index in the data array (i.e. not the token id)
@@ -1086,7 +1087,8 @@ extern "C" {
 
     /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 2e6550682..af5117e88 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -427,6 +427,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
 
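Note: with this hunk the dist sampler sorts and normalizes the candidates itself via the internal llama_sampler_softmax_impl helper before drawing. A rough sketch of what that normalization does, assuming only the public llama_token_data_array layout above — an illustration, not the actual implementation from src/llama-sampling.cpp:

#include <algorithm>
#include <cmath>

#include "llama.h"

// illustrative stand-in for the internal softmax step (assumes size >= 1):
// sort candidates by logit (descending), then apply a numerically stable softmax
static void softmax_sketch(llama_token_data_array * cur_p) {
    std::sort(cur_p->data, cur_p->data + cur_p->size,
            [](const llama_token_data & a, const llama_token_data & b) {
                return a.logit > b.logit;
            });

    const float max_logit = cur_p->data[0].logit;

    float cum_sum = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        // subtract the max logit before exponentiating for numerical stability
        cur_p->data[i].p = std::exp(cur_p->data[i].logit - max_logit);
        cum_sum += cur_p->data[i].p;
    }
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= cum_sum;
    }
}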
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 1372bdf13..7868aaa7a 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -24,20 +24,22 @@ static void dump(const llama_token_data_array * cur_p) {
         llama_sampler_free(cnstr); \
     } while(0)
 
+#define CUR_P_FROM_PROBS() \
+    const size_t n_vocab = probs.size(); \
+    std::vector<llama_token_data> cur; \
+    cur.reserve(n_vocab); \
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { \
+        const float logit = logf(probs[token_id]); \
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); \
+    } \
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }
+
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p); DUMP(&cur_p);
     APPLY(llama_sampler_init_top_k(k), &cur_p);
+    APPLY(llama_sampler_init_dist (0), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -47,19 +49,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float>
 }
 
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p); DUMP(&cur_p);
     APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
+    APPLY(llama_sampler_init_dist (0), &cur_p);
+    DUMP(&cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -69,16 +64,8 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float>
 }
 
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     DUMP(&cur_p);
     APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
     DUMP(&cur_p);
@@ -90,20 +77,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
 }
 
 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     DUMP(&cur_p);
     APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
+    APPLY(llama_sampler_init_dist (0), &cur_p);
     DUMP(&cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
     for (size_t i = 0; i < cur_p.size; i++) {
@@ -112,17 +91,8 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float>
 }
 
 static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
     APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
     DUMP(&cur_p);
@@ -134,16 +104,8 @@ static void test_xtc(const std::vector<float> & probs, const std::vector<float>
 }
 
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+    CUR_P_FROM_PROBS();
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     DUMP(&cur_p);
     APPLY(llama_sampler_init_typical(p, 1), &cur_p);
     DUMP(&cur_p);
@@ -160,16 +122,7 @@ static void test_penalties(
 ) {
     GGML_ASSERT(probs.size() == expected_probs.size());
 
-    const size_t n_vocab = probs.size();
-
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    CUR_P_FROM_PROBS();
 
     auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
 
@@ -177,10 +130,9 @@ static void test_penalties(
         llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
     APPLY(sampler, &cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
+    APPLY(llama_sampler_init_dist(0), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -214,7 +166,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
             default : GGML_ABORT("Unknown sampler");
         }
 
-        APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
+        APPLY(llama_sampler_init_dist(0), &cur_p);
 
         const int size = cur_p.size;
@@ -307,21 +259,20 @@ static void test_perf() {
     BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
     BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
     BENCH(llama_sampler_init_xtc      (1.0f, 0.1f, 1, 1), data, 32);
-    BENCH(llama_sampler_init_softmax  (), data, 32);
 }
 
 int main(void) {
     ggml_time_init();
 
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
 
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
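Note: the updated expectations encode the renormalization that the dist sampler now performs — whatever survives truncation is rescaled by the sum of the surviving probabilities. For instance, top-k with k = 3 keeps {0.4, 0.3, 0.2}, whose sum is 0.9, giving 0.4/0.9 ≈ 0.44444, 0.3/0.9 ≈ 0.33333, 0.2/0.9 ≈ 0.22222; top-p with p = 0.7 keeps {0.4, 0.3}, giving 0.4/0.7 ≈ 0.571429 and 0.3/0.7 ≈ 0.428571. A tiny standalone check of that arithmetic (not part of the test suite):

#include <cstdio>

int main(void) {
    // survivors of top-k with k = 3 out of {0.1, 0.2, 0.3, 0.4}
    const float kept[] = { 0.4f, 0.3f, 0.2f };

    float sum = 0.0f;
    for (float p : kept) sum += p;                   // 0.9

    for (float p : kept) printf("%.5f\n", p / sum);  // 0.44444, 0.33333, 0.22222
    return 0;
}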