mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
sampling : one sequence per sampling context
ggml-ci
This commit is contained in:
parent
370359e5ba
commit
5261aee8d8
@ -1,14 +1,5 @@
|
|||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
|
||||||
llama_sampling_context::~llama_sampling_context() {
|
|
||||||
for (auto & it : sequence_contexts) {
|
|
||||||
if (it.second.grammar != NULL) {
|
|
||||||
llama_grammar_free(it.second.grammar);
|
|
||||||
it.second.grammar = NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampling_context llama_sampling_context_init(
|
llama_sampling_context llama_sampling_context_init(
|
||||||
const struct gpt_params & params,
|
const struct gpt_params & params,
|
||||||
llama_grammar * grammar) {
|
llama_grammar * grammar) {
|
||||||
@ -16,45 +7,17 @@ llama_sampling_context llama_sampling_context_init(
|
|||||||
|
|
||||||
result.params = params.sampling_params;
|
result.params = params.sampling_params;
|
||||||
result.grammar = grammar;
|
result.grammar = grammar;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Creates the context if it doesn't exist, so this always return something.
|
|
||||||
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
|
|
||||||
llama_sampling_context & ctx_sampling,
|
|
||||||
const llama_seq_id seq) {
|
|
||||||
const auto it = ctx_sampling.sequence_contexts.find(seq);
|
|
||||||
if (it != ctx_sampling.sequence_contexts.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
llama_sampler_sequence_context new_ctx = {
|
|
||||||
2.0f * ctx_sampling.params.mirostat_tau,
|
|
||||||
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
|
|
||||||
};
|
|
||||||
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_sampling_context_reset(
|
|
||||||
llama_sampling_context & ctx_sampling,
|
|
||||||
const llama_seq_id seq) {
|
|
||||||
const auto it = ctx_sampling.sequence_contexts.find(seq);
|
|
||||||
if (it == ctx_sampling.sequence_contexts.end()) return false;
|
|
||||||
if (it->second.grammar != NULL) {
|
|
||||||
llama_grammar_free(it->second.grammar);
|
|
||||||
it->second.grammar = NULL;
|
|
||||||
}
|
|
||||||
ctx_sampling.sequence_contexts.erase(it);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_context * ctx_guidance,
|
struct llama_context * ctx_guidance,
|
||||||
struct llama_sampling_context & ctx_sampling,
|
struct llama_sampling_context & ctx_sampling,
|
||||||
const std::vector<llama_token> & last_tokens,
|
const std::vector<llama_token> & last_tokens,
|
||||||
std::vector<llama_token_data> & candidates,
|
std::vector<llama_token_data> & candidates,
|
||||||
const int idx,
|
const int idx) {
|
||||||
llama_seq_id seq) {
|
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
|
||||||
@ -115,10 +78,8 @@ llama_token llama_sampling_sample(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
|
if (ctx_sampling.grammar != NULL) {
|
||||||
|
llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
|
||||||
if (ctx_seq.grammar != NULL) {
|
|
||||||
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (temp <= 0) {
|
if (temp <= 0) {
|
||||||
@ -128,10 +89,10 @@ llama_token llama_sampling_sample(
|
|||||||
if (mirostat == 1) {
|
if (mirostat == 1) {
|
||||||
const int mirostat_m = 100;
|
const int mirostat_m = 100;
|
||||||
llama_sample_temp(ctx, &cur_p, temp);
|
llama_sample_temp(ctx, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
|
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
|
||||||
} else if (mirostat == 2) {
|
} else if (mirostat == 2) {
|
||||||
llama_sample_temp(ctx, &cur_p, temp);
|
llama_sample_temp(ctx, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
|
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
|
||||||
} else {
|
} else {
|
||||||
// Temperature sampling
|
// Temperature sampling
|
||||||
size_t min_keep = std::max(1, params.n_probs);
|
size_t min_keep = std::max(1, params.n_probs);
|
||||||
@ -158,8 +119,8 @@ llama_token llama_sampling_sample(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx_seq.grammar != NULL) {
|
if (ctx_sampling.grammar != NULL) {
|
||||||
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
|
llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
|
@ -34,27 +34,14 @@ typedef struct llama_sampling_params {
|
|||||||
|
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// per-sequence sampler context
|
|
||||||
typedef struct llama_sampler_sequence_context {
|
|
||||||
float mirostat_mu; // mirostat sampler state
|
|
||||||
llama_grammar * grammar;
|
|
||||||
} llama_sampler_sequence_context;
|
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
typedef struct llama_sampling_context {
|
typedef struct llama_sampling_context {
|
||||||
~llama_sampling_context();
|
// parameters that will be used for sampling
|
||||||
|
|
||||||
// parameters that will be used for sampling and when creating
|
|
||||||
// new llama_sampler_sequence_context instances
|
|
||||||
llama_sampling_params params;
|
llama_sampling_params params;
|
||||||
|
|
||||||
// map of sequence ids to sampler contexts
|
// mirostat sampler state
|
||||||
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
|
float mirostat_mu;
|
||||||
|
|
||||||
// when non-NULL, new instances of llama_sampler_sequence_context
|
|
||||||
// will get a copy of the grammar here
|
|
||||||
// note: only the pointer is stored here, it is not a copy of
|
|
||||||
// the grammar and shouldn't be freed
|
|
||||||
llama_grammar * grammar;
|
llama_grammar * grammar;
|
||||||
} llama_sampling_context;
|
} llama_sampling_context;
|
||||||
|
|
||||||
@ -65,13 +52,6 @@ llama_sampling_context llama_sampling_context_init(
|
|||||||
const struct gpt_params & params,
|
const struct gpt_params & params,
|
||||||
llama_grammar * grammar = NULL);
|
llama_grammar * grammar = NULL);
|
||||||
|
|
||||||
// Fetches the sampler context for the specified sequence id (defaults to 0).
|
|
||||||
// If the context for that sequence id doesn't already exist, it will be created with
|
|
||||||
// default values based on the parameters in the ctx_sampling argument.
|
|
||||||
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
|
|
||||||
llama_sampling_context & ctx_sampling,
|
|
||||||
const llama_seq_id seq = 0);
|
|
||||||
|
|
||||||
// Reset the sampler context for the supplied sequence id (defaults to 0).
|
// Reset the sampler context for the supplied sequence id (defaults to 0).
|
||||||
// This is necessary to reuse a sequence id or free memory used by sequences
|
// This is necessary to reuse a sequence id or free memory used by sequences
|
||||||
// that are no longer required.
|
// that are no longer required.
|
||||||
@ -104,5 +84,4 @@ llama_token llama_sampling_sample(
|
|||||||
struct llama_sampling_context & ctx_sampling,
|
struct llama_sampling_context & ctx_sampling,
|
||||||
const std::vector<llama_token> & last_tokens,
|
const std::vector<llama_token> & last_tokens,
|
||||||
std::vector<llama_token_data> & candidates,
|
std::vector<llama_token_data> & candidates,
|
||||||
const int idx = 0,
|
const int idx = 0);
|
||||||
llama_seq_id seq = 0);
|
|
||||||
|
@ -69,6 +69,8 @@ struct client {
|
|||||||
std::string response;
|
std::string response;
|
||||||
|
|
||||||
std::vector<llama_token> tokens_prev;
|
std::vector<llama_token> tokens_prev;
|
||||||
|
|
||||||
|
llama_sampling_context ctx_sampling;
|
||||||
};
|
};
|
||||||
|
|
||||||
static void print_date_time() {
|
static void print_date_time() {
|
||||||
@ -125,8 +127,6 @@ int main(int argc, char ** argv) {
|
|||||||
params.logits_all = true;
|
params.logits_all = true;
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
|
|
||||||
|
|
||||||
// load the prompts from an external file if there are any
|
// load the prompts from an external file if there are any
|
||||||
if (params.prompt.empty()) {
|
if (params.prompt.empty()) {
|
||||||
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
|
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
|
||||||
@ -156,6 +156,7 @@ int main(int argc, char ** argv) {
|
|||||||
client.id = i;
|
client.id = i;
|
||||||
client.tokens_prev.resize(std::max(256, params.n_predict));
|
client.tokens_prev.resize(std::max(256, params.n_predict));
|
||||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||||
|
client.ctx_sampling = llama_sampling_context_init(params, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -341,7 +342,7 @@ int main(int argc, char ** argv) {
|
|||||||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
|
const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling, client.tokens_prev, candidates, client.i_batch - i);
|
||||||
|
|
||||||
if (client.n_decoded == 1) {
|
if (client.n_decoded == 1) {
|
||||||
// start measuring generation time after the first token to make sure all concurrent clients
|
// start measuring generation time after the first token to make sure all concurrent clients
|
||||||
@ -386,7 +387,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
n_total_prompt += client.n_prompt;
|
n_total_prompt += client.n_prompt;
|
||||||
n_total_gen += client.n_decoded;
|
n_total_gen += client.n_decoded;
|
||||||
llama_sampling_context_reset(ctx_sampling, client.seq_id);
|
|
||||||
client.seq_id = -1;
|
client.seq_id = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,12 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
struct seq_draft {
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
|
||||||
|
struct llama_grammar * grammar = NULL;
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
@ -213,13 +219,8 @@ int main(int argc, char ** argv) {
|
|||||||
if (grammar_dft) {
|
if (grammar_dft) {
|
||||||
llama_grammar_free(grammar_dft);
|
llama_grammar_free(grammar_dft);
|
||||||
}
|
}
|
||||||
// Note: Hardcoded to sequence id 0, if this ever supports parallel generation
|
|
||||||
// that will need to change.
|
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
||||||
auto it = ctx_sampling.sequence_contexts.find(0);
|
|
||||||
GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
|
|
||||||
// This is necessary because each sequence id in sequence_contexts
|
|
||||||
// uses a copy of the original grammar.
|
|
||||||
grammar_dft = llama_grammar_copy(it->second.grammar);
|
|
||||||
|
|
||||||
LOG("copied target grammar to draft grammar\n");
|
LOG("copied target grammar to draft grammar\n");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user