diff --git a/common/sampling.cpp b/common/sampling.cpp index 8ce419459..e0704713f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,50 +1,14 @@ #include "sampling.h" -llama_sampling_context::~llama_sampling_context() { - for (auto & it : sequence_contexts) { - if (it.second.grammar != NULL) { - llama_grammar_free(it.second.grammar); - it.second.grammar = NULL; - } - } -} - llama_sampling_context llama_sampling_context_init( const struct gpt_params & params, llama_grammar * grammar) { - llama_sampling_context result; + llama_sampling_context result; - result.params = params.sampling_params; - result.grammar = grammar; - return result; -} + result.params = params.sampling_params; + result.grammar = grammar; -// Note: Creates the context if it doesn't exist, so this always return something. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it != ctx_sampling.sequence_contexts.end()) { - return it->second; - } - llama_sampler_sequence_context new_ctx = { - 2.0f * ctx_sampling.params.mirostat_tau, - ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, - }; - return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; -} - -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it == ctx_sampling.sequence_contexts.end()) return false; - if (it->second.grammar != NULL) { - llama_grammar_free(it->second.grammar); - it->second.grammar = NULL; - } - ctx_sampling.sequence_contexts.erase(it); - return true; + return result; } llama_token llama_sampling_sample( @@ -53,8 +17,7 @@ llama_token llama_sampling_sample( struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, - const int idx, - llama_seq_id seq) { + const int idx) { const int n_ctx = llama_n_ctx(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -115,10 +78,8 @@ llama_token llama_sampling_sample( } } - llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); - - if (ctx_seq.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); + if (ctx_sampling.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar); } if (temp <= 0) { @@ -128,10 +89,10 @@ llama_token llama_sampling_sample( if (mirostat == 1) { const int mirostat_m = 100; llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu); } else if (mirostat == 2) { llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu); } else { // Temperature sampling size_t min_keep = std::max(1, params.n_probs); @@ -158,8 +119,8 @@ llama_token llama_sampling_sample( } } - if (ctx_seq.grammar != NULL) { - llama_grammar_accept_token(ctx, ctx_seq.grammar, id); + if (ctx_sampling.grammar != NULL) { + llama_grammar_accept_token(ctx, ctx_sampling.grammar, id); } return id; diff --git a/common/sampling.h b/common/sampling.h index 0aab5d03c..fda5902a8 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -34,27 +34,14 @@ typedef struct llama_sampling_params { } llama_sampling_params; -// per-sequence sampler context -typedef struct llama_sampler_sequence_context { - float mirostat_mu; // mirostat sampler state - llama_grammar * grammar; -} llama_sampler_sequence_context; - // general sampler context typedef struct llama_sampling_context { - ~llama_sampling_context(); - - // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_context instances + // parameters that will be used for sampling llama_sampling_params params; - // map of sequence ids to sampler contexts - std::unordered_map sequence_contexts; + // mirostat sampler state + float mirostat_mu; - // when non-NULL, new instances of llama_sampler_sequence_context - // will get a copy of the grammar here - // note: only the pointer is stored here, it is not a copy of - // the grammar and shouldn't be freed llama_grammar * grammar; } llama_sampling_context; @@ -65,13 +52,6 @@ llama_sampling_context llama_sampling_context_init( const struct gpt_params & params, llama_grammar * grammar = NULL); -// Fetches the sampler context for the specified sequence id (defaults to 0). -// If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the ctx_sampling argument. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - // Reset the sampler context for the supplied sequence id (defaults to 0). // This is necessary to reuse a sequence id or free memory used by sequences // that are no longer required. @@ -104,5 +84,4 @@ llama_token llama_sampling_sample( struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, - const int idx = 0, - llama_seq_id seq = 0); + const int idx = 0); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 63ddcd8ed..165315db0 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -69,6 +69,8 @@ struct client { std::string response; std::vector tokens_prev; + + llama_sampling_context ctx_sampling; }; static void print_date_time() { @@ -125,8 +127,6 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL); - // load the prompts from an external file if there are any if (params.prompt.empty()) { printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); @@ -156,6 +156,7 @@ int main(int argc, char ** argv) { client.id = i; client.tokens_prev.resize(std::max(256, params.n_predict)); std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); + client.ctx_sampling = llama_sampling_context_init(params, NULL); } std::vector candidates; @@ -341,7 +342,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); + const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling, client.tokens_prev, candidates, client.i_batch - i); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -386,7 +387,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - llama_sampling_context_reset(ctx_sampling, client.seq_id); + client.seq_id = -1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 018dbf9a2..c3e97d71f 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -9,6 +9,12 @@ #include #include +struct seq_draft { + std::vector tokens; + + struct llama_grammar * grammar = NULL; +}; + int main(int argc, char ** argv) { gpt_params params; @@ -213,13 +219,8 @@ int main(int argc, char ** argv) { if (grammar_dft) { llama_grammar_free(grammar_dft); } - // Note: Hardcoded to sequence id 0, if this ever supports parallel generation - // that will need to change. - auto it = ctx_sampling.sequence_contexts.find(0); - GGML_ASSERT(it != ctx_sampling.sequence_contexts.end()); - // This is necessary because each sequence id in sequence_contexts - // uses a copy of the original grammar. - grammar_dft = llama_grammar_copy(it->second.grammar); + + grammar_dft = llama_grammar_copy(ctx_sampling.grammar); LOG("copied target grammar to draft grammar\n"); }