From 5fb5e24811cb01d48b482c15a974bfbd9f433e1d Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 9 Sep 2024 17:10:46 +0200 Subject: [PATCH] llama : minor sampling refactor (2) (#9386) --- examples/batched.swift/Sources/main.swift | 2 - examples/batched/batched.cpp | 2 - examples/gritlm/gritlm.cpp | 1 - .../llama/src/main/cpp/llama-android.cpp | 2 - .../llama.cpp.swift/LibLlama.swift | 2 - examples/passkey/passkey.cpp | 2 - examples/save-load-state/save-load-state.cpp | 6 - examples/server/server.cpp | 2 +- examples/simple/simple.cpp | 2 - include/llama.h | 11 +- src/llama-sampling.cpp | 194 ++++++++++-------- tests/test-sampling.cpp | 2 +- 12 files changed, 115 insertions(+), 113 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 4bc2bbf2c..9f7c49492 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -140,8 +140,6 @@ while n_cur <= n_len { let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) - llama_sampler_accept(smpl, new_token_id) - // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { i_batch[i] = -1 diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index f5f309022..615d6f0f5 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -172,8 +172,6 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); - llama_sampler_accept(smpl, new_token_id); - // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { i_batch[i] = -1; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index e1efbf573..6f060e2dc 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_decode(ctx, bat); llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); - llama_sampler_accept(smpl, token); if (token == eos_token) { break; diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 06ec160c2..f611809c6 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( // sample the most likely token const auto new_token_id = llama_sampler_sample(sampler, context, -1); - llama_sampler_accept(sampler, new_token_id); - const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { return nullptr; diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 92f61fe83..dcd9803a2 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -152,8 +152,6 @@ actor LlamaContext { new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - llama_sampler_accept(sampling, new_token_id) - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") is_done = true diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 76d235c2c..271ef3a98 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ 
-220,8 +220,6 @@ int main(int argc, char ** argv) { { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_accept(smpl, new_token_id); - // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { LOG_TEE("\n"); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index b54ec3bd8..e17ab0ed0 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -74,8 +74,6 @@ int main(int argc, char ** argv) { auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = llama_token_to_piece(ctx, next_token); - llama_sampler_accept(smpl, next_token); - printf("%s", next_token_str.c_str()); result0 += next_token_str; @@ -132,8 +130,6 @@ int main(int argc, char ** argv) { auto next_token = llama_sampler_sample(smpl2, ctx2, -1); auto next_token_str = llama_token_to_piece(ctx2, next_token); - llama_sampler_accept(smpl2, next_token); - printf("%s", next_token_str.c_str()); result1 += next_token_str; @@ -222,8 +218,6 @@ int main(int argc, char ** argv) { auto next_token = llama_sampler_sample(smpl3, ctx3, -1); auto next_token_str = llama_token_to_piece(ctx3, next_token); - llama_sampler_accept(smpl3, next_token); - printf("%s", next_token_str.c_str()); result2 += next_token_str; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9ab8f8ca6..de3ea313c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -613,7 +613,7 @@ struct server_context { gpt_params params; - llama_batch batch; + llama_batch batch = {}; bool clean_kv_cache = true; bool add_bos_token = true; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index a53cef547..d040172a5 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -118,8 +118,6 @@ int main(int argc, char ** argv) { { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_accept(smpl, new_token_id); - // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { LOG_TEE("\n"); diff --git a/include/llama.h b/include/llama.h index 6334fc30d..93b3e6e85 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1127,15 +1127,16 @@ extern "C" { int32_t n_logit_bias, const llama_logit_bias * logit_bias); - // Shorthand for: + /// @details Sample and accept a token from the idx-th output of the last evaluation // + // Shorthand for: // const auto * logits = llama_get_logits_ith(ctx, idx); // llama_token_data_array cur_p = { ... init from logits ... }; // llama_sampler_apply(smpl, &cur_p); - // return cur_p.data[cur_p.selected].id; - // - // At this point, this is mostly a convenience function. 
- // + // auto token = cur_p.data[cur_p.selected].id; + // llama_sampler_accept(smpl, token); + // return token; + // Returns the sampled token LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 41f48ec28..6f448b80c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -8,49 +8,44 @@ #include #include #include +#include #include #include #include -static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector & probs) { -#if 1 - probs.resize(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - probs[i] = cur_p->data[i].p; - } - - std::discrete_distribution dist(probs.begin(), probs.end()); -#else - // avoid the copy with a custom iterator +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { + // iterator for the probabilities +#ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif struct probs_iterator { typedef std::input_iterator_tag iterator_category; typedef float value_type; typedef float * pointer; typedef float & reference; - typedef size_t difference_type; + typedef ptrdiff_t difference_type; - const llama_token_data_array * data; - size_t i; + const llama_token_data * data; - bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; } - bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; } - float operator*() const { return data->data[i].p; } - probs_iterator & operator++() { ++i; return *this; } - probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; } + bool operator==(const probs_iterator & other) const { return data == other.data; } + bool operator!=(const probs_iterator & other) const { return data != other.data; } + const float & operator*() const { return data->p; } + probs_iterator & operator++() { ++data; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } }; + +#ifdef __GNUC__ #pragma GCC diagnostic pop - - std::discrete_distribution dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size}); - - GGML_UNUSED(probs); #endif + std::discrete_distribution dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); + return dist(rng); } +/* static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -64,6 +59,7 @@ static void llama_log_softmax(float * array, size_t size) { array[i] = logf(array[i] / sum); } } +*/ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); @@ -231,67 +227,92 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; llama_sampler_apply(smpl, &cur_p); - return cur_p.data[cur_p.selected].id; + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; } // sampler chain +static const char * llama_sampler_chain_name(const struct 
llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); - } - - chain->n_sample++; - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); - } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); - } - - chain->t_sample_us = 0; - chain->n_sample = 0; - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; - - auto * result = llama_sampler_chain_init(chain_src->params); - - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, llama_sampler_clone(smpl)); - } - - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); - } - - delete chain; - }, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -368,8 +389,6 @@ struct llama_sampler_dist { const uint32_t seed; std::mt19937 rng; - - std::vector probs; // work array }; static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { @@ -378,7 
+397,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_dist *) smpl->ctx; - cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); } static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { @@ -419,7 +438,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .ctx = */ new llama_sampler_dist { /* .seed = */ seed, /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, }, }; } @@ -1023,8 +1041,6 @@ struct llama_sampler_mirostat { float mu; std::mt19937 rng; - - std::vector probs; }; static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { @@ -1055,7 +1071,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); llama_sampler_softmax_impl(cur_p); - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; @@ -1111,7 +1127,6 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see /* .m = */ m, /* .mu = */ 2.0f*tau, /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, }, }; } @@ -1127,8 +1142,6 @@ struct llama_sampler_mirostat_v2 { float mu; std::mt19937 rng; - - std::vector probs; }; static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { @@ -1152,7 +1165,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t // Normalize the probabilities of the remaining words llama_sampler_softmax_impl(cur_p); - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; @@ -1207,7 +1220,6 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, /* .eta = */ eta, /* .mu = */ 2.0f*tau, /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, }, }; } @@ -1527,6 +1539,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * / static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + if (ctx->logit_bias.empty()) { + return; + } + ctx->to_search.clear(); // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) @@ -1538,6 +1554,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to } } + if (ctx->to_search.empty()) { + return; + } + // search for the remaining candidates that were not found in the previous step for (size_t i = 0; i < cur_p->size; ++i) { for (const auto & lb : ctx->to_search) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 37400c179..d738b7a45 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -245,7 +245,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler } } - printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n", + printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n", samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); }
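
Notes (editor addition, not part of the patch):

With this refactor, llama_sampler_sample() both applies the sampler chain and accepts the sampled token, which is why the standalone llama_sampler_accept() calls are removed from the examples above. A minimal sketch of one generation step against the new API, mirroring examples/simple/simple.cpp; model, ctx, smpl, batch, n_cur and n_predict are assumed to be set up already:

    // the sampler chain is applied and the chosen token is accepted
    // (sampler state such as penalties/grammar updated) inside llama_sampler_sample()
    const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);

    // is it an end of generation?
    if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
        // stop sampling for this sequence
    }

The rewritten llama_sample_dist() also drops the per-call std::vector work buffer: it feeds std::discrete_distribution through a small input iterator over the candidate array, so no probabilities are copied. A self-contained sketch of that technique follows; token_data here is a stand-in for llama_token_data, not the library type:

    #include <cstddef>
    #include <cstdio>
    #include <iterator>
    #include <random>

    struct token_data {
        int   id;
        float p; // normalized probability of this candidate
    };

    int main() {
        token_data cur[3] = { {10, 0.2f}, {11, 0.5f}, {12, 0.3f} };

        // minimal input iterator over the .p field of the candidate array
        struct probs_iterator {
            typedef std::input_iterator_tag iterator_category;
            typedef float          value_type;
            typedef float *        pointer;
            typedef float &        reference;
            typedef std::ptrdiff_t difference_type;

            const token_data * data;

            bool operator==(const probs_iterator & other) const { return data == other.data; }
            bool operator!=(const probs_iterator & other) const { return data != other.data; }
            const float & operator*() const { return data->p; }
            probs_iterator & operator++() { ++data; return *this; }
            probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
        };

        std::mt19937 rng(1234);
        std::discrete_distribution<> dist(probs_iterator{cur}, probs_iterator{cur + 3});

        const int idx = dist(rng); // index into cur[], weighted by p
        std::printf("sampled token id: %d\n", cur[idx].id);
        return 0;
    }

A signed difference_type is what the standard iterator requirements expect, which is presumably why the patch switches the iterator's difference_type from size_t to ptrdiff_t.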