sampling : one sequence per sampling context

ggml-ci
commit 5261aee8d8 (parent 370359e5ba)
Author: Georgi Gerganov
Date:   2023-10-12 20:35:01 +03:00
4 changed files with 28 additions and 86 deletions
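
In short: llama_sampling_context previously owned a map of per-sequence sampler
states, and llama_sampling_sample() selected among them with a seq id. After
this commit each sequence owns its own llama_sampling_context. A minimal sketch
of the call-site change, condensed from the parallel example below:

    // before: one shared context, sequence selected by id
    const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling,
                                                 last_tokens, candidates, idx, seq_id);

    // after: one context per sequence, no seq id parameter
    const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling,
                                                 last_tokens, candidates, idx);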

common/sampling.cpp

@@ -1,50 +1,14 @@
 #include "sampling.h"
 
-llama_sampling_context::~llama_sampling_context() {
-    for (auto & it : sequence_contexts) {
-        if (it.second.grammar != NULL) {
-            llama_grammar_free(it.second.grammar);
-            it.second.grammar = NULL;
-        }
-    }
-}
-
 llama_sampling_context llama_sampling_context_init(
         const struct gpt_params & params,
                   llama_grammar * grammar) {
     llama_sampling_context result;
 
     result.params = params.sampling_params;
     result.grammar = grammar;
-    return result;
-}
 
-// Note: Creates the context if it doesn't exist, so this always return something.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq) {
-    const auto it = ctx_sampling.sequence_contexts.find(seq);
-    if (it != ctx_sampling.sequence_contexts.end()) {
-        return it->second;
-    }
-
-    llama_sampler_sequence_context new_ctx = {
-        2.0f * ctx_sampling.params.mirostat_tau,
-        ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
-    };
-
-    return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
-}
-
-bool llama_sampling_context_reset(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq) {
-    const auto it = ctx_sampling.sequence_contexts.find(seq);
-    if (it == ctx_sampling.sequence_contexts.end()) return false;
-
-    if (it->second.grammar != NULL) {
-        llama_grammar_free(it->second.grammar);
-        it->second.grammar = NULL;
-    }
-
-    ctx_sampling.sequence_contexts.erase(it);
-    return true;
+    return result;
 }
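
An aside on the hunk above: the deleted llama_sampling_get_sequence_context()
seeded the mirostat state with 2.0f * params.mirostat_tau (the conventional
mu_0 = 2*tau starting point), while the surviving init function never touches
mirostat_mu. If a caller needs the seeded value, one option (hypothetical, not
part of this diff) is to set it explicitly after init:

    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
    ctx_sampling.mirostat_mu = 2.0f * params.sampling_params.mirostat_tau; // mu_0 = 2*tau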
@@ -53,8 +17,7 @@
 llama_token llama_sampling_sample(
         struct llama_sampling_context & ctx_sampling,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
-        const int idx,
-        llama_seq_id seq) {
+        const int idx) {
     const int n_ctx = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -115,10 +78,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
-
-    if (ctx_seq.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
+    if (ctx_sampling.grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
     }
 
     if (temp <= 0) {
@@ -128,10 +89,10 @@ llama_token llama_sampling_sample(
         if (mirostat == 1) {
             const int mirostat_m = 100;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
         } else if (mirostat == 2) {
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
@@ -158,8 +119,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_seq.grammar != NULL) {
-        llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
+    if (ctx_sampling.grammar != NULL) {
+        llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
     }
 
     return id;
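
Ownership note for this file: the destructor that freed the per-sequence
grammar copies is gone, and llama_sampling_context_init() now stores the
caller's llama_grammar pointer directly. The context only borrows the grammar,
so freeing it remains the caller's job. A minimal sketch, assuming the grammar
is created by the caller somewhere outside this diff:

    llama_grammar * grammar = /* owned by the caller; may be NULL */ NULL;

    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);

    // ... run the sampling loop with ctx_sampling ...

    // the context never frees the grammar, so the caller must:
    if (grammar != NULL) {
        llama_grammar_free(grammar);
    }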

common/sampling.h

@@ -34,27 +34,14 @@ typedef struct llama_sampling_params {
 } llama_sampling_params;
 
-// per-sequence sampler context
-typedef struct llama_sampler_sequence_context {
-    float mirostat_mu; // mirostat sampler state
-    llama_grammar * grammar;
-} llama_sampler_sequence_context;
-
 // general sampler context
 typedef struct llama_sampling_context {
-    ~llama_sampling_context();
-
-    // parameters that will be used for sampling and when creating
-    // new llama_sampler_sequence_context instances
+    // parameters that will be used for sampling
     llama_sampling_params params;
 
-    // map of sequence ids to sampler contexts
-    std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
+    // mirostat sampler state
+    float mirostat_mu;
 
-    // when non-NULL, new instances of llama_sampler_sequence_context
-    // will get a copy of the grammar here
-    // note: only the pointer is stored here, it is not a copy of
-    // the grammar and shouldn't be freed
     llama_grammar * grammar;
 } llama_sampling_context;
@@ -65,13 +52,6 @@ llama_sampling_context llama_sampling_context_init(
     const struct gpt_params & params,
           llama_grammar * grammar = NULL);
 
-// Fetches the sampler context for the specified sequence id (defaults to 0).
-// If the context for that sequence id doesn't already exist, it will be created with
-// default values based on the parameters in the ctx_sampling argument.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq = 0);
-
 // Reset the sampler context for the supplied sequence id (defaults to 0).
 // This is necessary to reuse a sequence id or free memory used by sequences
 // that are no longer required.
@@ -104,5 +84,4 @@ llama_token llama_sampling_sample(
     struct llama_sampling_context & ctx_sampling,
     const std::vector<llama_token> & last_tokens,
     std::vector<llama_token_data> & candidates,
-    const int idx = 0,
-    llama_seq_id seq = 0);
+    const int idx = 0);
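
With the sequence_contexts map removed from the struct, handling N parallel
sequences becomes the caller's responsibility: keep one llama_sampling_context
per sequence, as the parallel example below does with a per-client member. A
sketch of that pattern (the n_seq loop is illustrative, not code from this
commit):

    std::vector<llama_sampling_context> ctx_samplings(n_seq);
    for (int s = 0; s < n_seq; ++s) {
        ctx_samplings[s] = llama_sampling_context_init(params, NULL);
    }

    // sampling for sequence s then uses that sequence's own context:
    // const llama_token id = llama_sampling_sample(ctx, NULL, ctx_samplings[s],
    //                                              last_tokens, candidates, idx);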

examples/parallel/parallel.cpp

@@ -69,6 +69,8 @@ struct client {
     std::string response;
 
     std::vector<llama_token> tokens_prev;
+
+    llama_sampling_context ctx_sampling;
 };
 
 static void print_date_time() {
@@ -125,8 +127,6 @@ int main(int argc, char ** argv) {
     params.logits_all = true;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
-
     // load the prompts from an external file if there are any
     if (params.prompt.empty()) {
         printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -156,6 +156,7 @@ int main(int argc, char ** argv) {
         client.id = i;
         client.tokens_prev.resize(std::max(256, params.n_predict));
         std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+        client.ctx_sampling = llama_sampling_context_init(params, NULL);
     }
 
     std::vector<llama_token_data> candidates;
@@ -341,7 +342,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
+            const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling, client.tokens_prev, candidates, client.i_batch - i);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -386,7 +387,7 @@ int main(int argc, char ** argv) {
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
 
-                llama_sampling_context_reset(ctx_sampling, client.seq_id);
-
                 client.seq_id = -1;
             }
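
A behavioral detail in this file: the old per-sequence reset call is dropped
rather than replaced, so a client's sampler state (mirostat_mu in particular)
now persists across prompt assignments. If fresh state per assignment is
wanted, one option (an assumption, not something this commit does) is to
re-initialize the member when the client picks up a new prompt:

    // hypothetical: clear a client's sampler state on reassignment
    client.ctx_sampling = llama_sampling_context_init(params, NULL);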

examples/speculative/speculative.cpp

@@ -9,6 +9,12 @@
 #include <string>
 #include <vector>
 
+struct seq_draft {
+    std::vector<llama_token> tokens;
+
+    struct llama_grammar * grammar = NULL;
+};
+
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -213,13 +219,8 @@ int main(int argc, char ** argv) {
             if (grammar_dft) {
                 llama_grammar_free(grammar_dft);
             }
-            // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
-            // that will need to change.
-            auto it = ctx_sampling.sequence_contexts.find(0);
-            GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
-            // This is necessary because each sequence id in sequence_contexts
-            // uses a copy of the original grammar.
-            grammar_dft = llama_grammar_copy(it->second.grammar);
+
+            grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
 
             LOG("copied target grammar to draft grammar\n");
         }