mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
llama : add new llama_decode() API that works with llama_batch
This commit is contained in:
parent
58bb5110ca
commit
9f42e75489
@ -780,7 +780,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
LOG("warming up the model with an empty run\n");
|
LOG("warming up the model with an empty run\n");
|
||||||
|
|
||||||
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
||||||
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
|
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
|
||||||
llama_reset_timings(lctx);
|
llama_reset_timings(lctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ int main(int argc, char ** argv)
|
|||||||
|
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
|
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
|
||||||
{
|
{
|
||||||
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
|
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){
|
|||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
|
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
|
||||||
|
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
|
|||||||
if (n_eval > params.n_batch) {
|
if (n_eval > params.n_batch) {
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
while (!embd_inp.empty()) {
|
while (!embd_inp.empty()) {
|
||||||
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||||
if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||||||
int n_processed = 0;
|
int n_processed = 0;
|
||||||
while (n_processed < n_prompt) {
|
while (n_processed < n_prompt) {
|
||||||
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
||||||
llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
|
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
|
||||||
n_processed += n_tokens;
|
n_processed += n_tokens;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||||
llama_token token = llama_token_bos(ctx);
|
llama_token token = llama_token_bos(ctx);
|
||||||
for (int i = 0; i < n_gen; i++) {
|
for (int i = 0; i < n_gen; i++) {
|
||||||
llama_eval(ctx, &token, 1, n_past + i, n_threads);
|
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -571,7 +571,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
for (int i = 0; i < input_size; i += params.n_batch) {
|
for (int i = 0; i < input_size; i += params.n_batch) {
|
||||||
int n_eval = std::min(input_size - i, params.n_batch);
|
int n_eval = std::min(input_size - i, params.n_batch);
|
||||||
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
|
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
|
||||||
LOG_TEE("%s : failed to eval\n", __func__);
|
LOG_TEE("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
||||||
|
|
||||||
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
|
||||||
LOG_TEE("%s : failed to eval\n", __func__);
|
LOG_TEE("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
|
||||||
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||||
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||||
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
tokens[batch_start] = llama_token_bos(ctx);
|
tokens[batch_start] = llama_token_bos(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
|
|||||||
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
||||||
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
||||||
n_tokens = std::min(n_tokens, size_t(n_batch));
|
n_tokens = std::min(n_tokens, size_t(n_batch));
|
||||||
if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
@ -34,11 +34,11 @@ int main(int argc, char ** argv) {
|
|||||||
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
|
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
|
||||||
|
|
||||||
// init
|
// init
|
||||||
auto model = llama_load_model_from_file(params.model.c_str(), lparams);
|
auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
auto ctx = llama_new_context_with_model(model, lparams);
|
auto * ctx = llama_new_context_with_model(model, lparams);
|
||||||
if (ctx == nullptr) {
|
if (ctx == nullptr) {
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return 1;
|
return 1;
|
||||||
@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluate prompt
|
// evaluate prompt
|
||||||
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
|
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
|
||||||
|
|
||||||
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
||||||
n_past += n_prompt_tokens;
|
n_past += n_prompt_tokens;
|
||||||
@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
|
|||||||
printf("\n%s", params.prompt.c_str());
|
printf("\n%s", params.prompt.c_str());
|
||||||
|
|
||||||
for (auto i = 0; i < params.n_predict; i++) {
|
for (auto i = 0; i < params.n_predict; i++) {
|
||||||
auto logits = llama_get_logits(ctx);
|
auto * logits = llama_get_logits(ctx);
|
||||||
auto n_vocab = llama_n_vocab(ctx);
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -90,7 +90,7 @@ int main(int argc, char ** argv) {
|
|||||||
last_n_tokens_data.push_back(next_token);
|
last_n_tokens_data.push_back(next_token);
|
||||||
|
|
||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
|
||||||
// make new context
|
// make new context
|
||||||
auto ctx2 = llama_new_context_with_model(model, lparams);
|
auto * ctx2 = llama_new_context_with_model(model, lparams);
|
||||||
|
|
||||||
// Load state (rng, logits, embedding and kv_cache) from file
|
// Load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// second run
|
// second run
|
||||||
for (auto i = 0; i < params.n_predict; i++) {
|
for (auto i = 0; i < params.n_predict; i++) {
|
||||||
auto logits = llama_get_logits(ctx2);
|
auto * logits = llama_get_logits(ctx2);
|
||||||
auto n_vocab = llama_n_vocab(ctx2);
|
auto n_vocab = llama_n_vocab(ctx2);
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -150,7 +150,7 @@ int main(int argc, char ** argv) {
|
|||||||
last_n_tokens_data.push_back(next_token);
|
last_n_tokens_data.push_back(next_token);
|
||||||
|
|
||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
@ -434,7 +434,7 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
|
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
|
||||||
{
|
{
|
||||||
LOG_ERROR("failed to eval", {
|
LOG_ERROR("failed to eval", {
|
||||||
{"n_eval", n_eval},
|
{"n_eval", n_eval},
|
||||||
|
@ -76,7 +76,7 @@ int main(int argc, char ** argv) {
|
|||||||
while (n_cur < n_gen) {
|
while (n_cur < n_gen) {
|
||||||
// evaluate the transformer
|
// evaluate the transformer
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_enc_start = ggml_time_us();
|
const auto t_enc_start = ggml_time_us();
|
||||||
|
|
||||||
// eval the prompt with both models
|
// eval the prompt with both models
|
||||||
llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
|
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
|
||||||
llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
|
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
|
||||||
llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
|
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
|
||||||
|
|
||||||
const auto t_enc_end = ggml_time_us();
|
const auto t_enc_end = ggml_time_us();
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("out of drafted tokens\n");
|
LOG("out of drafted tokens\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
// heuristic for n_draft
|
// heuristic for n_draft
|
||||||
@ -256,7 +256,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the drafted token on the draft model
|
// evaluate the drafted token on the draft model
|
||||||
llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
|
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
|
||||||
++n_past_cur;
|
++n_past_cur;
|
||||||
|
|
||||||
if (grammar_dft != NULL) {
|
if (grammar_dft != NULL) {
|
||||||
@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
|
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
|
|
||||||
// the first token is always proposed by the traget model before the speculation loop
|
// the first token is always proposed by the traget model before the speculation loop
|
||||||
|
119
llama.cpp
119
llama.cpp
@ -1265,7 +1265,7 @@ static bool llama_kv_cache_init(
|
|||||||
// updates the cache head
|
// updates the cache head
|
||||||
static bool llama_kv_cache_find_slot(
|
static bool llama_kv_cache_find_slot(
|
||||||
struct llama_kv_cache & cache,
|
struct llama_kv_cache & cache,
|
||||||
struct llama_batch & batch) {
|
const struct llama_batch & batch) {
|
||||||
const uint32_t n_ctx = cache.size;
|
const uint32_t n_ctx = cache.size;
|
||||||
const uint32_t n_tokens = batch.n_tokens;
|
const uint32_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
@ -2522,7 +2522,7 @@ static bool llama_model_load(
|
|||||||
|
|
||||||
static struct ggml_cgraph * llm_build_llama(
|
static struct ggml_cgraph * llm_build_llama(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch) {
|
const llama_batch & batch) {
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
@ -2876,7 +2876,7 @@ static struct ggml_cgraph * llm_build_llama(
|
|||||||
|
|
||||||
static struct ggml_cgraph * llm_build_baichaun(
|
static struct ggml_cgraph * llm_build_baichaun(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch) {
|
const llama_batch & batch) {
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
@ -3247,7 +3247,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||||||
|
|
||||||
static struct ggml_cgraph * llm_build_falcon(
|
static struct ggml_cgraph * llm_build_falcon(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch) {
|
const llama_batch & batch) {
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
@ -3577,7 +3577,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||||||
|
|
||||||
static struct ggml_cgraph * llm_build_starcoder(
|
static struct ggml_cgraph * llm_build_starcoder(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch) {
|
const llama_batch & batch) {
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
@ -3819,7 +3819,7 @@ static struct ggml_cgraph * llm_build_starcoder(
|
|||||||
|
|
||||||
static struct ggml_cgraph * llama_build_graph(
|
static struct ggml_cgraph * llama_build_graph(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch) {
|
const llama_batch & batch) {
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
|
|
||||||
struct ggml_cgraph * result = NULL;
|
struct ggml_cgraph * result = NULL;
|
||||||
@ -3856,7 +3856,7 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
//
|
//
|
||||||
static bool llama_eval_internal(
|
static bool llama_eval_internal(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch & batch,
|
llama_batch batch,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
const uint32_t n_tokens = batch.n_tokens;
|
const uint32_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
@ -3886,6 +3886,31 @@ static bool llama_eval_internal(
|
|||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_vocab = hparams.n_vocab;
|
const int64_t n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
std::vector<llama_seq_id> seq_id;
|
||||||
|
|
||||||
|
if (batch.pos == nullptr) {
|
||||||
|
pos.resize(n_tokens);
|
||||||
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
|
pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.pos = pos.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.seq_id == nullptr) {
|
||||||
|
seq_id.resize(n_tokens);
|
||||||
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
|
seq_id[i] = batch.all_seq_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.seq_id = seq_id.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.clear_kv) {
|
||||||
|
llama_kv_cache_clear(kv_self, 0, -1);
|
||||||
|
}
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -4820,6 +4845,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar)
|
|||||||
// sampling
|
// sampling
|
||||||
//
|
//
|
||||||
|
|
||||||
|
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
|
||||||
|
if (seed == LLAMA_DEFAULT_SEED) {
|
||||||
|
seed = time(NULL);
|
||||||
|
}
|
||||||
|
ctx->rng.seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||||
GGML_ASSERT(candidates->size > 0);
|
GGML_ASSERT(candidates->size > 0);
|
||||||
|
|
||||||
@ -5469,7 +5501,7 @@ struct llama_beam_search_data {
|
|||||||
} else {
|
} else {
|
||||||
// beam is not at end-of-sentence, so branch with next top_k tokens.
|
// beam is not at end-of-sentence, so branch with next top_k tokens.
|
||||||
if (!beam.tokens.empty()) {
|
if (!beam.tokens.empty()) {
|
||||||
llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
|
llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads);
|
||||||
}
|
}
|
||||||
llama_logit_info logit_info(ctx);
|
llama_logit_info logit_info(ctx);
|
||||||
std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
|
std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
|
||||||
@ -5543,7 +5575,7 @@ struct llama_beam_search_data {
|
|||||||
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
|
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
|
||||||
update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed.
|
update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed.
|
||||||
if (common_prefix_length) {
|
if (common_prefix_length) {
|
||||||
llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
|
llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads);
|
||||||
n_past += common_prefix_length;
|
n_past += common_prefix_length;
|
||||||
}
|
}
|
||||||
// Zero-out next_beam probabilities to place them last in following min-heap.
|
// Zero-out next_beam probabilities to place them last in following min-heap.
|
||||||
@ -6505,8 +6537,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
// build worst-case graph
|
// build worst-case graph
|
||||||
uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
|
uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
|
||||||
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||||
llama_batch batch = { n_tokens, &token, nullptr, nullptr, nullptr };
|
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0));
|
||||||
ggml_cgraph * gf = llama_build_graph(*ctx, batch);
|
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (params.n_gpu_layers > 0) {
|
if (params.n_gpu_layers > 0) {
|
||||||
@ -6714,15 +6745,6 @@ void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
|
|||||||
llama_kv_cache_clear(ctx->kv_self, p0, p1);
|
llama_kv_cache_clear(ctx->kv_self, p0, p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
|
||||||
|
|
||||||
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
|
|
||||||
if (seed == LLAMA_DEFAULT_SEED) {
|
|
||||||
seed = time(NULL);
|
|
||||||
}
|
|
||||||
ctx->rng.seed(seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the *maximum* size of the state
|
// Returns the *maximum* size of the state
|
||||||
size_t llama_get_state_size(const struct llama_context * ctx) {
|
size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
||||||
@ -7116,21 +7138,9 @@ int llama_eval(
|
|||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
std::vector<llama_pos> pos(n_tokens);
|
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
||||||
pos[i] = n_past + i;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<llama_seq_id> seq_id(n_tokens);
|
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
||||||
seq_id[i] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_batch batch = { n_tokens, tokens, nullptr, pos.data(), seq_id.data(), };
|
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx->kv_self, n_past, -1);
|
llama_kv_cache_clear(ctx->kv_self, n_past, -1);
|
||||||
|
|
||||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -7151,18 +7161,47 @@ int llama_eval_embd(
|
|||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
std::vector<llama_pos> pos(n_tokens);
|
llama_kv_cache_clear(ctx->kv_self, n_past, -1);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
||||||
pos[i] = n_past + i;
|
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, };
|
||||||
|
|
||||||
|
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_seq_id> seq_id(n_tokens);
|
// get a more accurate load time, upon first eval
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
// TODO: fix this
|
||||||
seq_id[i] = 0;
|
if (!ctx->has_evaluated_once) {
|
||||||
|
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
|
||||||
|
ctx->has_evaluated_once = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch = { n_tokens, nullptr, embd, pos.data(), seq_id.data(), };
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_batch llama_batch_get_one(
|
||||||
|
const llama_token * tokens,
|
||||||
|
uint32_t n_tokens,
|
||||||
|
llama_pos pos_0,
|
||||||
|
llama_seq_id seq_id) {
|
||||||
|
return {
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ tokens,
|
||||||
|
/*embd =*/ nullptr,
|
||||||
|
/*pos =*/ nullptr,
|
||||||
|
/*seq_id =*/ nullptr,
|
||||||
|
/*all_pos_0 =*/ pos_0,
|
||||||
|
/*all_pos_1 =*/ 1,
|
||||||
|
/*all_seq_id =*/ seq_id,
|
||||||
|
/*clear_kv =*/ pos_0 == 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_decode(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_batch batch,
|
||||||
|
int n_threads) {
|
||||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
|
45
llama.h
45
llama.h
@ -37,6 +37,8 @@
|
|||||||
|
|
||||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||||
|
|
||||||
|
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
@ -70,9 +72,20 @@ extern "C" {
|
|||||||
|
|
||||||
// TODO: not sure about these consts - might just get in the way all the time with no benefit
|
// TODO: not sure about these consts - might just get in the way all the time with no benefit
|
||||||
const llama_token * token;
|
const llama_token * token;
|
||||||
const float * embd;
|
const float * embd;
|
||||||
const llama_pos * pos;
|
const llama_pos * pos;
|
||||||
const llama_seq_id * seq_id;
|
const llama_seq_id * seq_id;
|
||||||
|
|
||||||
|
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||||
|
// for future-proof code, use the above fields instead and ignore everything below
|
||||||
|
//
|
||||||
|
// pos[i] = all_pos_0 + i*all_pos_1
|
||||||
|
//
|
||||||
|
llama_pos all_pos_0; // used if pos == NULL
|
||||||
|
llama_pos all_pos_1; // used if pos == NULL
|
||||||
|
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||||
|
|
||||||
|
bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
|
||||||
} llama_seq;
|
} llama_seq;
|
||||||
|
|
||||||
enum llama_log_level {
|
enum llama_log_level {
|
||||||
@ -312,9 +325,6 @@ extern "C" {
|
|||||||
|
|
||||||
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
|
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
|
||||||
|
|
||||||
// Sets the current rng seed.
|
|
||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
|
||||||
|
|
||||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
// and kv_cache) - will often be smaller after compacting tokens
|
// and kv_cache) - will often be smaller after compacting tokens
|
||||||
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
||||||
@ -336,19 +346,37 @@ extern "C" {
|
|||||||
// tokens + n_tokens is the provided batch of new tokens to process
|
// tokens + n_tokens is the provided batch of new tokens to process
|
||||||
// n_past is the number of tokens to use from previous eval calls
|
// n_past is the number of tokens to use from previous eval calls
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
LLAMA_API int llama_eval(
|
LLAMA_API DEPRECATED(int llama_eval(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads),
|
||||||
|
"please use llama_decode() instead");
|
||||||
|
|
||||||
// Same as llama_eval, but use float matrix input directly.
|
// Same as llama_eval, but use float matrix input directly.
|
||||||
LLAMA_API int llama_eval_embd(
|
LLAMA_API DEPRECATED(int llama_eval_embd(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const float * embd,
|
const float * embd,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
|
int n_threads),
|
||||||
|
"please use llama_decode() instead");
|
||||||
|
|
||||||
|
// Return batch for single sequence of tokens starting at pos_0
|
||||||
|
// If pos_0 == 0, the clear_kv flag will be auto set to true
|
||||||
|
//
|
||||||
|
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||||
|
//
|
||||||
|
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||||
|
const llama_token * tokens,
|
||||||
|
uint32_t n_tokens,
|
||||||
|
llama_pos pos_0,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
LLAMA_API int llama_decode(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_batch batch,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
// Token logits obtained from the last call to llama_eval()
|
// Token logits obtained from the last call to llama_eval()
|
||||||
@ -434,6 +462,9 @@ extern "C" {
|
|||||||
// Sampling functions
|
// Sampling functions
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// Sets the current rng seed.
|
||||||
|
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user