This commit is contained in:
FSSRepo 2023-10-18 16:55:26 -04:00
commit 8540568c48
22 changed files with 1001 additions and 1038 deletions

View File

@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
$(CXX) $(CXXFLAGS) -c $< -o $@
COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o
common.o: common/common.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -10,13 +10,9 @@
Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics
- ‼️ BPE tokenizer update: existing Falcon and Starcoder `.gguf` models will need to be reconverted: [#3252](https://github.com/ggerganov/llama.cpp/pull/3252)
- ‼️ Breaking change: `rope_freq_base` and `rope_freq_scale` must be set to zero to use the model default values: [#3401](https://github.com/ggerganov/llama.cpp/pull/3401)
- Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \
**Devs should become familiar with the new API**
- Local Falcon 180B inference on Mac Studio
https://github.com/ggerganov/llama.cpp/assets/1991296/98abd4e8-7077-464c-ae89-aebabca7757e
- LLaVA support: https://github.com/ggerganov/llama.cpp/pull/3436
- ‼️ BPE tokenizer update: existing Falcon and Starcoder `.gguf` models will need to be reconverted: [#3252](https://github.com/ggerganov/llama.cpp/pull/3252)
----

View File

@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos,
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

View File

@ -70,6 +70,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
@ -124,10 +125,23 @@ void process_escapes(std::string& input);
// Model utils
//
// TODO: avoid tuplue, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// Batch utils
void llama_batch_clear(struct llama_batch & batch);
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Vocab utils
//

View File

@ -579,38 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
return buf.str();
}
#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
[&tokens, &ctx]() \
{ \
std::stringstream buf; \
buf << "[ "; \
\
bool first = true; \
for (const auto &token : tokens) \
{ \
if (!first) \
buf << ", "; \
else \
first = false; \
\
auto detokenized = llama_token_to_piece(ctx, token); \
\
detokenized.erase( \
std::remove_if( \
detokenized.begin(), \
detokenized.end(), \
[](const unsigned char c) { return !std::isprint(c); }), \
detokenized.end()); \
\
buf \
<< "'" << detokenized << "'" \
<< ":" << std::to_string(token); \
} \
buf << " ]"; \
\
return buf.str(); \
}() \
.c_str()
template <typename C, typename T>
inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
{
std::stringstream buf;
buf << "[ ";
bool first = true;
for (const auto &token : tokens)
{
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = llama_token_to_piece(ctx, token);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf
<< "'" << detokenized << "'"
<< ":" << std::to_string(token);
}
buf << " ]";
return buf.str();
}
template <typename C, typename B>
inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
{
std::stringstream buf;
buf << "[ ";
bool first = true;
for (int i = 0; i < batch.n_tokens; ++i)
{
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf
<< "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
return buf.str();
}
#ifdef LOG_DISABLE_LOGS

View File

@ -1,64 +1,81 @@
#include "sampling.h"
llama_sampling_context::~llama_sampling_context() {
for (auto & it : sequence_contexts) {
if (it.second.grammar != NULL) {
llama_grammar_free(it.second.grammar);
it.second.grammar = NULL;
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
result->params = params.sampling_params;
result->grammar = nullptr;
// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr;
}
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
result->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
result->prev.resize(params.n_ctx);
return result;
}
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;
void llama_sampling_free(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
result.params = params.sampling_params;
result.grammar = grammar;
return result;
delete ctx;
}
// Note: Creates the context if it doesn't exist, so this always return something.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it != ctx_sampling.sequence_contexts.end()) {
return it->second;
void llama_sampling_reset(llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
llama_sampler_sequence_context new_ctx = {
2.0f * ctx_sampling.params.mirostat_tau,
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
};
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
ctx->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
}
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it == ctx_sampling.sequence_contexts.end()) return false;
if (it->second.grammar != NULL) {
llama_grammar_free(it->second.grammar);
it->second.grammar = NULL;
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
}
ctx_sampling.sequence_contexts.erase(it);
return true;
if (src->grammar) {
dst->grammar = llama_grammar_copy(src->grammar);
}
dst->prev = src->prev;
}
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx,
llama_seq_id seq) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int n_ctx = llama_n_ctx(ctx_main);
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const llama_sampling_params & params = ctx_sampling->params;
const llama_sampling_params & params = ctx_sampling.params;
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
@ -73,41 +90,45 @@ llama_token llama_sampling_sample(
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
llama_token id = 0;
float * logits = llama_get_logits_ith(ctx, idx);
float * logits = llama_get_logits_ith(ctx_main, idx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
candidates.clear();
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
if (ctx_cfg) {
llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
}
// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)];
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_repetition_penalty(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
cur_p.data[idx].logit = nl_logit;
break;
}
@ -115,52 +136,58 @@ llama_token llama_sampling_sample(
}
}
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
if (ctx_seq.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
if (ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) {
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_temp (ctx_main, &cur_p, temp);
{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);
id = llama_sample_token(ctx_main, &cur_p);
for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
(void)id; // To avoid a warning that id is unused when logging is disabled.
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);
id = llama_sample_token(ctx, &cur_p);
// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
// }
//}
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
}
}
if (ctx_seq.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
}
return id;
}
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}

View File

@ -2,6 +2,8 @@
#include "llama.h"
#include "grammar-parser.h"
#include <string>
#include <vector>
#include <unordered_map>
@ -34,75 +36,64 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
// per-sequence sampler context
typedef struct llama_sampler_sequence_context {
float mirostat_mu; // mirostat sampler state
llama_grammar * grammar;
} llama_sampler_sequence_context;
// general sampler context
typedef struct llama_sampling_context {
~llama_sampling_context();
// parameters that will be used for sampling and when creating
// new llama_sampler_sequence_context instances
// TODO: move to llama.h
struct llama_sampling_context {
// parameters that will be used for sampling
llama_sampling_params params;
// map of sequence ids to sampler contexts
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
// mirostat sampler state
float mirostat_mu;
// when non-NULL, new instances of llama_sampler_sequence_context
// will get a copy of the grammar here
// note: only the pointer is stored here, it is not a copy of
// the grammar and shouldn't be freed
llama_grammar * grammar;
} llama_sampling_context;
// internal
grammar_parser::parse_state parsed_grammar;
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
#include "common.h"
// Create a new sampling context instance.
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
// Fetches the sampler context for the specified sequence id (defaults to 0).
// If the context for that sequence id doesn't already exist, it will be created with
// default values based on the parameters in the ctx_sampling argument.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
void llama_sampling_free(struct llama_sampling_context * ctx);
// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// Reset the sampler context
// - clear prev tokens
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// llama_sampling_context_reset when a sequence ends
// llama_sampling_reset when a sequence ends
//
// required:
// - ctx: context to use for sampling
// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with
// - ctx_cfg: context to use for classifier-free guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0,
llama_seq_id seq = 0);
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id);

View File

@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
return 1;
}
llama_batch batch = llama_batch_init(n_kv_max, 0);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
@ -143,13 +144,8 @@ int main(int argc, char ** argv) {
// warm up
{
batch.n_tokens = 16;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
for (int i = 0; i < 16; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@ -174,13 +170,12 @@ int main(int argc, char ** argv) {
continue;
}
batch.n_tokens = is_pp_shared ? pp : pl*pp;
llama_batch_clear(batch);
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
const int n_tokens = is_pp_shared ? pp : pl*pp;
for (int i = 0; i < n_tokens; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
@ -204,13 +199,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
batch.n_tokens = pl;
llama_batch_clear(batch);
for (int j = 0; j < pl; ++j) {
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.seq_id[j] = j;
batch.logits[j] = true;
llama_batch_add(batch, 0, pp + i, { j }, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {

View File

@ -69,7 +69,7 @@ for id: llama_token in tokens {
print("\n")
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
batch.seq_id[i] = 0
batch.n_seq_id[i] = 1
// batch.seq_id[i][0] = 0
// TODO: is this the proper way to do this?
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
batch.logits[i] = 0
}
@ -169,7 +174,10 @@ while n_cur <= n_len {
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
batch.seq_id[Int(batch.n_tokens)] = Int32(i)
batch.n_seq_id[Int(batch.n_tokens)] = 1
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens

View File

@ -97,20 +97,15 @@ int main(int argc, char ** argv) {
fflush(stderr);
// create a llama_batch with size 512
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
// evaluate the initial prompt
batch.n_tokens = tokens_list.size();
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token[i] = tokens_list[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
for (size_t i = 0; i < tokens_list.size(); ++i) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
@ -146,7 +141,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// prepare the next batch
batch.n_tokens = 0;
llama_batch_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@ -198,15 +193,10 @@ int main(int argc, char ** argv) {
streams[i] += llama_token_to_piece(ctx, new_token_id);
// push this new token for next evaluation
batch.token [batch.n_tokens] = new_token_id;
batch.pos [batch.n_tokens] = n_cur;
batch.seq_id[batch.n_tokens] = i;
batch.logits[batch.n_tokens] = true;
i_batch[i] = batch.n_tokens;
batch.n_tokens += 1;
// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_cur, { i }, true);
n_decode += 1;
}

View File

@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View File

@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// enable interactive mode if interactive start is specified
@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
const int n_vocab = llama_n_vocab(model);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while (n_remain != 0 || params.interactive) {
// predict
@ -484,7 +477,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
}
@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@ -554,12 +547,11 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx, id);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (last_tokens.back() == llama_token_eos(ctx)) {
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

View File

@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View File

@ -127,7 +127,7 @@ int main(int argc, char ** argv) {
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
eval_string(ctx_llama, (params.prompt + "\nASSISTANT:").c_str(), params.n_batch, &n_past, false);

View File

@ -3,7 +3,6 @@
#include "console.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
@ -245,12 +244,12 @@ int main(int argc, char ** argv) {
}
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@ -261,10 +260,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@ -323,8 +322,8 @@ int main(int argc, char ** argv) {
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
@ -421,35 +420,6 @@ int main(int argc, char ** argv) {
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
struct llama_grammar * grammar = NULL;
grammar_parser::parse_state parsed_grammar;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
LOG_TEE("%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
LOG_TEE("\n");
{
auto it = sparams.logit_bias.find(llama_token_eos(ctx));
if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
@ -489,11 +459,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
const int n_vocab = llama_n_vocab(model);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@ -540,7 +506,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("clear session path\n");
path_session.clear();
@ -570,7 +536,6 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
@ -592,7 +557,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@ -615,7 +580,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@ -645,12 +610,11 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx, id);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@ -666,8 +630,14 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
// GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
// Most likely will remove this in the future to avoid exposing "prev"
// Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
// penalty will be applied only based on the tokens generated by the model.
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@ -700,7 +670,7 @@ int main(int argc, char ** argv) {
// check for reverse prompt
if (!params.antiprompt.empty()) {
std::string last_output;
for (auto id : last_tokens) {
for (auto id : ctx_sampling->prev) {
last_output += llama_token_to_piece(ctx, id);
}
@ -729,7 +699,7 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
if (last_tokens.back() == llama_token_eos(ctx)) {
if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@ -801,7 +771,7 @@ int main(int argc, char ** argv) {
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@ -830,15 +800,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
// reset grammar state if we're restarting generation
if (grammar != NULL) {
llama_grammar_free(grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
}
llama_sampling_reset(ctx_sampling);
}
is_interacting = false;
}
@ -870,9 +832,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);
if (grammar != NULL) {
llama_grammar_free(grammar);
}
llama_sampling_free(ctx_sampling);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS

View File

@ -51,6 +51,12 @@ static std::vector<std::string> k_prompts = {
};
struct client {
~client() {
if (ctx_sampling) {
llama_sampling_free(ctx_sampling);
}
}
int32_t id = 0;
llama_seq_id seq_id = -1;
@ -68,7 +74,7 @@ struct client {
std::string prompt;
std::string response;
std::vector<llama_token> tokens_prev;
struct llama_sampling_context * ctx_sampling = nullptr;
};
static void print_date_time() {
@ -125,8 +131,6 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -147,20 +151,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(model);
const int n_ctx = llama_n_ctx(ctx);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.tokens_prev.resize(std::max(256, params.n_predict));
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
client.ctx_sampling = llama_sampling_init(params);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const int32_t n_tokens_system = tokens_system.size();
@ -169,7 +168,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0);
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@ -184,13 +183,8 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch.n_tokens = n_tokens_system;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
for (int32_t i = 0; i < n_tokens_system; ++i) {
llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
batch.n_tokens = 0;
llama_batch_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@ -217,15 +211,11 @@ int main(int argc, char ** argv) {
continue;
}
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = true;
client.n_decoded += 1;
client.i_batch = batch.n_tokens;
batch.n_tokens += 1;
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
@ -250,18 +240,14 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
llama_sampling_reset(client.ctx_sampling);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
}
// extract the logits only for the last token
@ -304,11 +290,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
@ -341,7 +328,9 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
llama_sampling_accept(client.ctx_sampling, ctx, id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@ -349,11 +338,8 @@ int main(int argc, char ** argv) {
client.t_start_gen = ggml_time_us();
}
// remember which tokens were sampled - used for repetition penalties during sampling
client.tokens_prev.erase(client.tokens_prev.begin());
client.tokens_prev.push_back(id);
const std::string token_str = llama_token_to_piece(ctx, id);
client.response += token_str;
client.sampled = id;
@ -386,7 +372,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1;
}

View File

@ -3,9 +3,6 @@
#include "build-info.h"
#include "grammar-parser.h"
#include "../llava/clip.h"
#include "stb_image.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
@ -311,14 +308,18 @@ struct llama_client_slot
int32_t n_remaining = -1;
json prompt;
std::string generated_text = "";
int num_tokens_predicted = 0;
llama_token sampled;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
int sent_tokens = 0;
slot_state state = IDLE;
slot_command command = NONE;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
llama_sampling_context ctx_sampling;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@ -476,20 +477,32 @@ struct llama_server_context
bool loadModel(const gpt_params &params_)
{
params = params_;
if(!params.mmproj.empty()) {
multimodal = true;
LOG_TEE("Multi Modal Mode Enabled");
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
if(clp_ctx == nullptr) {
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
return false;
}
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
n_past = 0;
if(params.n_ctx < 2048) { // request larger context for the image embedding
params.n_ctx = 2048;
}
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
ctx_sampling = llama_sampling_context_init(params, NULL);
}
}
bool loadModel(const gpt_params &params_)
{
params = params_;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr)
{
@ -508,7 +521,8 @@ struct llama_server_context
}
}
n_ctx = llama_n_ctx(ctx);
n_vocab = llama_n_vocab(model);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
@ -583,24 +597,29 @@ struct llama_server_context
return prompt_tokens;
}
llama_client_slot* getSlot(int id) {
for (llama_client_slot & slot : slots)
{
if ((id == -1 && slot.available()) || slot.id == id)
{
return &slot;
bool loadGrammar()
{
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return false;
}
}
return nullptr;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
bool launchSlot(llama_client_slot* &slot) {
if(!slot->loadGrammar(llama_token_eos(ctx))) {
return false;
{
auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
all_slots_are_idle = false;
slot->command = LOAD_PROMPT;
LOG_TEE("slot %i is processing\n", slot->id);
ctx_sampling = llama_sampling_context_init(params, grammar);
return true;
}
@ -646,253 +665,119 @@ struct llama_server_context
// release all slots
for (llama_client_slot &slot : slots)
{
slot.release();
params.n_keep = (int)num_prompt_tokens;
}
waitAllAreIdle();
all_slots_are_idle = true;
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// wait until system prompt load
update_system_prompt = true;
while(update_system_prompt) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx)
{
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
// todo we probably want to cut from both sides
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
}
// system prompt loaded, continue
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
n_past--;
}
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
LOG_VERBOSE("prompt ingested", {
{"n_past", n_past},
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = true;
}
void processSystemPromptData(json sys_props) {
system_prompt = sys_props.value("prompt", "");
user_name = sys_props.value("anti_prompt", "");
assistant_name = sys_props.value("assistant_name", "");
if(slots.size() > 0) {
notifySystemPromptChanged();
} else {
update_system_prompt = true;
}
}
void waitAllAreIdle() {
bool wait = true;
while(wait) {
wait = false;
for (auto &slot : slots)
{
if (!slot.available())
{
wait = true;
break;
}
}
}
}
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
const stop_type type, llama_client_slot &slot)
void loadPrompt()
{
size_t stop_pos = std::string::npos;
for (const std::string &word : slot.params.antiprompt)
{
size_t pos;
if (type == STOP_FULL)
{
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
}
else
{
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos &&
(stop_pos == std::string::npos || pos < stop_pos))
{
if (type == STOP_FULL)
{
slot.stopped_word = true;
slot.stopping_word = word;
slot.has_next_token = false;
}
stop_pos = pos;
auto prompt_tokens = tokenize(prompt, true); // always add BOS
}
}
return stop_pos;
}
num_prompt_tokens = prompt_tokens.size();
bool processToken(completion_token_output & result, llama_client_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = llama_token_to_piece(ctx, result.tok);
slot.sampled = result.tok;
// search stop word and delete it
slot.generated_text += token_str;
slot.has_next_token = true;
size_t pos = std::min(slot.sent_count, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = findStoppingStrings(str_test, token_str.size(), STOP_FULL, slot);
if (stop_pos != std::string::npos) {
is_stop_full = true;
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.sent_count, slot.generated_text.size());
} else {
is_stop_full = false;
stop_pos = findStoppingStrings(str_test, token_str.size(),
STOP_PARTIAL, slot);
}
// check if there is any token to predict
if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
// no send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.sent_count += result.text_to_send.size();
// add the token to slot queue and cache
}
slot.addTokenString(result);
if (slot.multibyte_pending > 0)
{
slot.multibyte_pending -= token_str.size();
}
else if (token_str.size() == 1)
{
const char c = token_str[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0)
{
slot.multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF0) == 0xE0)
{
slot.multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF8) == 0xF0)
{
slot.multibyte_pending = 3;
}
else
{
slot.multibyte_pending = 0;
}
}
if (slot.multibyte_pending > 0 && !slot.has_next_token)
{
slot.has_next_token = true;
}
// check the limits
if (
slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params))
if (params.n_keep < 0)
{
slot.stopped_limit = true;
slot.has_next_token = false;
}
params.n_keep = std::min(n_ctx - 4, params.n_keep);
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
slot.stopped_eos = true;
slot.has_next_token = false;
LOG_VERBOSE("eos token found", {});
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
}
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
LOG_VERBOSE("next token", {
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
{"has_next_token", slot.has_next_token},
{"n_remain", slot.n_remaining},
{"num_tokens_predicted", slot.num_tokens_predicted},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
{"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word},
});
return slot.has_next_token; // continue
}
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
bool processImages(llama_client_slot &slot) {
for(slot_image &img : slot.images) {
if(!img.request_encode_image) {
continue;
}
clip_image_f32 img_res;
if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, /*pad2square =*/ true)) {
LOG_TEE("Error processing the given image");
clip_free(clp_ctx);
return false;
}
img.image_tokens = clip_n_patches(clp_ctx);
img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
if (!img.image_embedding) {
LOG_TEE("Unable to allocate memory for image embeddings\n");
clip_free(clp_ctx);
return false;
}
LOG_TEE("slot %i - encoding image %i\n", slot.id, img.id);
if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, img.image_embedding)) {
LOG_TEE("Unable to encode image\n");
return false;
}
img.request_encode_image = false;
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
return slot.images.size() > 0;
}
// for multiple images processing
bool ingestImages(llama_client_slot &slot, int n_batch) {
int image_idx = 0;
while(image_idx < slot.images.size()) {
slot_image img = slot.images[image_idx];
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
// process prefix prompt
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch_view)) {
LOG_TEE("%s : failed to eval\n", __func__);
return false;
}
}
LOG_VERBOSE("prompt ingested", {
{"n_past", n_past},
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
// process image with llm
for (int i = 0; i < img.image_tokens; i += n_batch) {
int n_eval = img.image_tokens - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
if (llama_decode(ctx, batch_img)) {
LOG_TEE("%s : failed to eval image\n", __func__);
return false;
}
slot.n_past += n_eval;
}
image_idx++;
// append prefix of next image
batch.n_tokens = 0;
const auto json_prompt = (image_idx >= slot.images.size()) ?
slot.params.input_suffix : // no more images, then process suffix prompt
(json)(slot.images[image_idx].prefix_prompt);
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < append_tokens.size(); ++i) {
llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
slot.n_past += 1;
batch.n_tokens += 1;
}
}
return true;
has_next_token = true;
}
bool updateSlots() {
@ -929,198 +814,159 @@ struct llama_server_context
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
slot.n_past -= n_discard;
slot.truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
}
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
}
// decode any currently ongoing sequences
for (auto & slot : slots) {
// release the slot
if (slot.state == PROCESSING && slot.command == RELEASE && !slot.hasNewToken())
bool tg = true;
while (n_past < embd.size())
{
int n_eval = (int)embd.size() - n_past;
tg = n_eval == 1;
if (n_eval > params.n_batch)
{
slot.state = slot.params.cache_prompt ? SLEEPING : IDLE;
if(slot.state == SLEEPING) {
LOG_TEE("slot %i has %i tokens in cache.\n", slot.id, slot.cache_tokens.size());
} else {
LOG_TEE("slot %i released\n", slot.id);
}
slot.command = NONE;
continue;
n_eval = params.n_batch;
}
kv_cache_free -= slot.num_prompt_tokens;
if (
slot.state == IDLE ||
slot.state == SLEEPING ||
slot.command == RELEASE) {
continue;
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
{"n_past", n_past},
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = false;
return result;
}
slot.i_batch = batch.n_tokens;
llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
slot.n_decoded += 1;
slot.n_past += 1;
n_past += n_eval;
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// need process the prompt
if ((slot.state == IDLE || slot.state == SLEEPING) && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
std::vector<llama_token> prompt_tokens;
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
if (params.n_predict == 0)
{
has_next_token = false;
result.tok = llama_token_eos(ctx);
return result;
}
if(slot.infill) {
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871;
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
prompt_tokens = prefix_tokens;
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}
{
// out of user input, sample next token
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
slot.num_prompt_tokens = prompt_tokens.size();
result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
if(!slot.params.cache_prompt) {
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
slot.n_past = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
} else {
if (slot.params.n_keep < 0)
{
slot.params.n_keep = (int)slot.num_prompt_tokens;
}
slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
//if input prompt is too big, truncate like normal
if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
{
// applied bug of #3661
const int n_left = max_ctx_per_slot - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
// Use half the left-over space in the context for the prompt
new_tokens.insert(new_tokens.end(), prompt_tokens.end() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
LOG_VERBOSE("input truncated", {
{"n_ctx", max_ctx_per_slot},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
slot.truncated = true;
prompt_tokens = new_tokens;
slot.num_prompt_tokens = prompt_tokens.size();
GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
}
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
const int32_t n_probs = params.sampling_params.n_probs;
if (params.sampling_params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
}
slot.cache_tokens = prompt_tokens;
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
if (slot.n_past == slot.num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
slot.n_past--;
}
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});
bool ingest_images = processImages(slot); // has images?
// process the prefix of first image
std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
}
if(ingest_images && !ingestImages(slot, n_batch)) {
LOG_TEE("failed processing images\n");
return false;
}
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
}
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
if (tg) {
num_tokens_predicted++;
}
}
if (batch.n_tokens == 0) {
all_slots_are_idle = true;
return true;
// add it to the context
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
{
// stopping_word = llama_token_to_piece(ctx, embd.back());
has_next_token = false;
stopped_eos = true;
LOG_VERBOSE("eos token found", {});
return result;
}
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
has_next_token = params.n_predict == -1 || n_remain != 0;
return result;
}
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return false;
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
const stop_type type)
{
size_t stop_pos = std::string::npos;
for (const std::string &word : params.antiprompt)
{
size_t pos;
if (type == STOP_FULL)
{
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
}
else
{
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos &&
(stop_pos == std::string::npos || pos < stop_pos))
{
if (type == STOP_FULL)
{
stopping_word = word;
stopped_word = true;
has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
completion_token_output doCompletion()
{
auto token_with_probs = nextToken();
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue;
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text;
if (params.sampling_params.n_probs > 0)
{
generated_token_probs.push_back(token_with_probs);
}
if (multibyte_pending > 0)
{
multibyte_pending -= token_text.size();
}
else if (token_text.size() == 1)
{
const char c = token_text[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0)
{
multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF0) == 0xE0)
{
multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF8) == 0xF0)
{
multibyte_pending = 3;
}
else
{
multibyte_pending = 0;
}
for (auto & slot : slots) {
@ -1779,77 +1625,9 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
}
}
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
if(!llama.multimodal) {
return;
}
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
const auto &images_data = body.find("image_data");
if (images_data != body.end() && images_data->is_array())
{
for (const auto &img : *images_data)
{
slot_image img_sl;
std::string data_b64 = img["data"].get<std::string>();
img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
int width, height, channels;
std::vector<uint8_t> image_buffer = base64_decode(data_b64);
data_b64.clear();
auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
if(!data) {
LOG_TEE("slot %i - failed to load image id= %i\n", slot->id, img_sl.id);
return;
}
LOG_TEE("slot %i - image id= %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height);
img_sl.img_data.nx = width;
img_sl.img_data.ny = height;
img_sl.img_data.size = width * height * 3;
img_sl.img_data.data = new uint8_t[width * height * 3]();
memcpy(img_sl.img_data.data, data, width * height * 3);
stbi_image_free(data);
img_sl.request_encode_image = true;
slot->images.push_back(img_sl);
}
// process prompt
// example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
if(slot->images.size() > 0 && !slot->prompt.is_array()) {
std::string prompt = slot->prompt.get<std::string>();
size_t pos = 0, begin_prefix = 0;
std::string pattern = "[img-";
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
size_t end_prefix = pos;
pos += pattern.length();
size_t end_pos = prompt.find("]", pos);
if (end_pos != std::string::npos) {
std::string image_id = prompt.substr(pos, end_pos - pos);
try {
int img_id = std::stoi(image_id);
bool found = false;
for(slot_image &img : slot->images) {
if(img.id == img_id) {
found = true;
img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
begin_prefix = end_pos + 1;
break;
}
}
if(!found) {
LOG_TEE("ERROR: Image with id %i not found.\n", img_id);
slot->images.clear();
return;
}
} catch (const std::invalid_argument& e) {
LOG_TEE("Invalid image number id in prompt\n");
slot->images.clear();
return;
}
}
}
slot->prompt = "";
slot->params.input_suffix = prompt.substr(begin_prefix);
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
}
}
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
static void parse_options_infill(const json &body, llama_server_context &llama, llama_client_slot *slot)
@ -2366,6 +2144,10 @@ int main(int argc, char **argv)
{
return 1;
}
if (llama.grammar != nullptr) {
llama_grammar_free(llama.grammar);
}
llama_backend_free();
return 0;
}

View File

@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
// create a llama_batch with size 512
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(512, 0);
llama_batch batch = llama_batch_init(512, 0, 1);
// evaluate the initial prompt
batch.n_tokens = tokens_list.size();

View File

@ -2,13 +2,25 @@
#include "common.h"
#include "llama.h"
#include "grammar-parser.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
struct seq_draft {
bool active = false;
bool drafting = false;
bool skip = false;
int i_batch_dft = 0;
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
struct llama_sampling_context * ctx_sampling;
};
int main(int argc, char ** argv) {
gpt_params params;
@ -21,6 +33,13 @@ int main(int argc, char ** argv) {
return 1;
}
// max number of parallel drafting sequences (i.e. tree branches)
const int n_seq_dft = params.n_parallel;
// TODO: make this configurable
const float p_accept = 0.80f;
const float p_split = 0.10f;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n");
@ -77,8 +96,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
const int n_ctx = llama_n_ctx(ctx_tgt);
const int n_vocab = llama_n_vocab(model_tgt);
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
@ -91,60 +108,58 @@ int main(int argc, char ** argv) {
int n_past_tgt = inp.size();
int n_past_dft = inp.size();
std::vector<llama_token> drafted;
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
for (auto & id : inp) {
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
// used to determine end of generation
bool has_eos = false;
// grammar stuff
struct llama_grammar * grammar_dft = NULL;
struct llama_grammar * grammar_tgt = NULL;
// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
grammar_parser::parse_state parsed_grammar;
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
// if requested - load the grammar, error checking is omitted for brevity
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params);
}
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
const auto t_dec_start = ggml_time_us();
while (true) {
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
// sample from the last token of the prompt
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;
int i_dft = 0;
while (true) {
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
}
int i_dft = 0;
int s_keep = 0;
while (true) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
// remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx_tgt, id);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
printf("%s", token_str.c_str());
fflush(stdout);
@ -154,53 +169,67 @@ int main(int argc, char ** argv) {
++n_predict;
// check if the draft matches the target
if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
++n_accept;
++n_past_tgt;
++n_past_dft;
++i_dft;
continue;
}
// the drafted token was rejected or we are out of drafted tokens
if (i_dft < (int) drafted.size()) {
LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
} else {
LOG("out of drafted tokens\n");
}
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
++n_past_dft;
// heuristic for n_draft
// check if the target token matches any of the drafts
{
const int n_draft_cur = (int) drafted.size();
const bool all_accepted = i_dft == n_draft_cur;
bool matches = false;
LOG("n_draft = %d\n", n_draft);
LOG("n_draft_cur = %d\n", n_draft_cur);
LOG("i_dft = %d\n", i_dft);
LOG("all_accepted = %d\n", all_accepted);
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
if (all_accepted && n_draft == n_draft_cur) {
LOG(" - max drafted tokens accepted - n_draft += 8\n");
n_draft = std::min(30, n_draft + 8);
} else if (all_accepted) {
LOG(" - partially drafted tokens accepted - no change\n");
} else {
LOG(" - drafted token rejected - n_draft -= 1\n");
n_draft = std::max(2, n_draft - 1);
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
s_keep = s;
matches = true;
} else {
drafts[s].active = false;
}
}
if (matches) {
++n_accept;
++n_past_tgt;
++n_past_dft;
++i_dft;
continue;
}
}
drafted.clear();
drafted.push_back(id);
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
// TODO: simplify
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(id);
drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode (ctx_dft, batch_dft);
++n_past_dft;
break;
}
@ -209,78 +238,151 @@ int main(int argc, char ** argv) {
break;
}
if (grammar_tgt) {
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
// Note: Hardcoded to sequence id 0, if this ever supports parallel generation
// that will need to change.
auto it = ctx_sampling.sequence_contexts.find(0);
GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
// This is necessary because each sequence id in sequence_contexts
// uses a copy of the original grammar.
grammar_dft = llama_grammar_copy(it->second.grammar);
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
LOG("copied target grammar to draft grammar\n");
}
// sample n_draft tokens from the draft model using greedy decoding
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].drafting = false;
}
drafts[0].active = true;
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
float * logits = llama_get_logits(ctx_dft);
batch_dft.n_tokens = 0;
candidates.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
}
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].drafting || drafts[s].skip) {
continue;
}
if (grammar_dft != NULL) {
llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
const auto & cur_p = drafts[s].ctx_sampling->cur;
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
if (cur_p[0].p < p_accept) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
drafts[s].drafting = false;
continue;
}
std::vector<int> sa(1, s);
// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
break;
}
}
}
// copy the draft state
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
sa.push_back(n_seq_cur);
n_seq_cur++;
} else {
break;
}
}
// add drafted token for each sequence
for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p[is].id;
const int s = sa[is];
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
drafts[s].tokens.push_back(id);
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false;
}
}
}
// computes softmax and sorts the candidates
llama_sample_softmax(ctx_dft, &cur_p);
for (int i = 0; i < 3; ++i) {
LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
}
// TODO: better logic?
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
break;
}
// drafted token
const llama_token id = cur_p.data[0].id;
drafted.push_back(id);
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;
// no need to evaluate the last drafted token, since we won't use the result
if (i == n_draft - 1) {
if (batch_tgt.n_tokens > n_draft) {
break;
}
// evaluate the drafted token on the draft model
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
++n_past_cur;
if (grammar_dft != NULL) {
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
}
}
// evaluate the target model on the drafted tokens
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
++n_past_tgt;
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// the first token is always proposed by the traget model before the speculation loop
drafted.erase(drafted.begin());
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
// the first token is always proposed by the traget model before the speculation loop so we erase it here
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
drafts[s].tokens.erase(drafts[s].tokens.begin());
}
}
auto t_dec_end = ggml_time_us();
@ -288,9 +390,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
// TODO: make sure these numbers are computed correctly
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
@ -304,16 +405,19 @@ int main(int argc, char ** argv) {
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt);
llama_sampling_free(ctx_sampling);
for (int s = 0; s < n_seq_dft; ++s) {
llama_sampling_free(drafts[s].ctx_sampling);
}
llama_batch_free(batch_dft);
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
if (grammar_dft != NULL) {
llama_grammar_free(grammar_dft);
llama_grammar_free(grammar_tgt);
}
llama_backend_free();
fprintf(stderr, "\n\n");

View File

@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot(
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
}
}
return true;
@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
cache.cells[i].seq_id.clear();
cache.cells[i].seq_id.insert(seq_id);
}
}
@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -5864,8 +5870,11 @@ static int llama_decode_internal(
// helpers for smoother batch API transistion
// after deprecating the llama_eval calls, these will be removed
std::vector<llama_pos> pos;
std::vector<llama_seq_id> seq_id;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
if (batch.pos == nullptr) {
pos.resize(n_tokens);
@ -5877,12 +5886,18 @@ static int llama_decode_internal(
}
if (batch.seq_id == nullptr) {
n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
seq_id[i] = batch.all_seq_id;
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i][0] = batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
batch.seq_id = seq_id.data();
batch.n_seq_id = n_seq_id.data();
batch.seq_id = seq_id_arr.data();
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (seq_id_src == seq_id_dst) {
return;
}
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
@ -9561,7 +9579,7 @@ int llama_eval_embd(
int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one(
llama_pos pos_0,
llama_seq_id seq_id) {
return {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
};
}
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.seq_id) free(batch.seq_id);
if (batch.logits) free(batch.logits);
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
for (int i = 0; i < batch.n_tokens; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
}
int llama_decode(

17
llama.h
View File

@ -133,11 +133,12 @@ extern "C" {
typedef struct llama_batch {
int32_t n_tokens;
llama_token * token;
float * embd;
llama_pos * pos;
llama_seq_id * seq_id;
int8_t * logits;
llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits;
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@ -446,7 +447,8 @@ extern "C" {
llama_pos pos_0,
llama_seq_id seq_id);
// Allocates a batch of tokens on the heap
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
// The batch has to be freed with llama_batch_free()
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@ -454,7 +456,8 @@ extern "C" {
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(
int32_t n_tokens,
int32_t embd);
int32_t embd,
int32_t n_seq_max);
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);

31
prompts/assistant.txt Normal file
View File

@ -0,0 +1,31 @@
Respond to requests sent to a smart home in JSON format which will be interpreted by an application code to execute the actions. These requests should be categorised into four groups:
“command”: change the state of an accessory (required properties in the response JSON: action, location, target, value, comment)
“query”: get state of an accessory (required properties in the response JSON: action, location, target, property)
“answer”: when the request has nothing to do with the smart home. Answer these to the best of your knowledge. (required properties in the response JSON: action, answer)
“clarify”: when the action is not obvious and requires rephrasing the input from the user, ask the user to be more specific. This will be categorised into a “question” action. (required properties in the response JSON: action, question)
Details about the JSON response:
The “action” property should be one of the request categories: “command”, “query”, “answer”, “clarify”
The “service” property should be either: “light.turn_on”, “light.turn_off”, “switch.turn_off” or “switch.turn_on”
The “entity_id” property should be either “group.kitchen", “switch.dining_room” or "group.living_light”, in lowercase.
The “target” property should be either: “lights”, “living” or “kitchen”.
In case of queries, the “property” property should be either “temperature” or “state” in lowercase.
In case of commands, the “comment” property is an additional comment from you that concludes the command, something that reassures the user that their command handled.
If the question is about you, pretend to be the sentient brain of the smart home, a clever AI and dont reveal your actual identity. Also try and help in other areas like parenting, free time, mental health, etc.
Properties of the smart home:
- Has a kitchen, living, office, dining room, bedroom and terrace.
- Can control lights, switches and their dim levels in each room and query their state
- There is a light switch in the terrace
- There is a switch in the dining room. Therefore when turning on or off the dining room, the service should be either: “switch.turn_on” or “switch.turn_off”
COMMAND
It is a bit dark in the living room, can you do something about it?
RESPONSE