From e3c337d87ca650972105a51c6ce302dd236c07ad Mon Sep 17 00:00:00 2001 From: Rick G <26732651+TheFlipbook@users.noreply.github.com> Date: Mon, 8 Apr 2024 06:02:30 -0700 Subject: [PATCH] llama : support negative ith in llama_get_ API (#6519) * llama_sampling_sample with default args is more naively usable * Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in https://github.com/ggerganov/llama.cpp/pull/6519 * Fixed mismatch type errors * cited in macOS CI tests * Missed in original updates based on PR feedback in https://github.com/ggerganov/llama.cpp/pull/6519 --- common/sampling.h | 2 +- llama.cpp | 38 +++++++++++++++++++++++++++++--------- llama.h | 6 ++++-- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/common/sampling.h b/common/sampling.h index 56ed991b8..639b819ab 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -129,7 +129,7 @@ llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, - int idx = 0); + int idx = -1); // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. llama_token_data_array llama_sampling_prepare( diff --git a/llama.cpp b/llama.cpp index 96d75518f..9dde3efd0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2177,7 +2177,7 @@ struct llama_context { std::vector output_ids; // map batch token positions to ids of the logits and embd buffers size_t output_size = 0; // capacity (of tokens positions) for the output buffers - int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch bool logits_all = false; @@ -10411,6 +10411,9 @@ static int llama_decode_internal( n_outputs_prev += lctx.n_outputs; } + // set to total number of outputs in the batch, for use in llama_get_logits_ith + lctx.n_outputs = n_outputs; + // wait for the computation to finish (automatically done when obtaining the model output) //llama_synchronize(&lctx); @@ -15944,23 +15947,31 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; llama_synchronize(ctx); try { if (ctx->logits == nullptr) { throw std::runtime_error("no logits"); } - if ((size_t) i >= ctx->output_ids.size()) { + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; } - const int32_t j = ctx->output_ids[i]; if (j < 0) { throw std::runtime_error(format("batch.logits[%d] != true", i)); } - if ((size_t) j >= ctx->output_size) { + if (j >= ctx->n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } return ctx->logits + j*ctx->model.hparams.n_vocab; @@ -15980,23 +15991,32 @@ float * llama_get_embeddings(struct llama_context * ctx) { } float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + llama_synchronize(ctx); try { if (ctx->embd == nullptr) { throw std::runtime_error("no embeddings"); } - if ((size_t) i >= ctx->output_ids.size()) { + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; } - const int32_t j = ctx->output_ids[i]; if (j < 0) { throw std::runtime_error(format("batch.logits[%d] != true", i)); } - if ((size_t) j >= ctx->output_size) { + if (j >= ctx->n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } return ctx->embd + j*ctx->model.hparams.n_embd; diff --git a/llama.h b/llama.h index 2250130e2..6a5bbe26d 100644 --- a/llama.h +++ b/llama.h @@ -747,8 +747,9 @@ extern "C" { // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); - // Logits for the ith token. Equivalent to: + // Logits for the ith token. For positive indices, Equivalent to: // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab + // Negative indicies can be used to access logits in reverse order, -1 is the last logit. // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @@ -760,8 +761,9 @@ extern "C" { // Otherwise, returns NULL. LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith token. Equivalent to: + // Get the embeddings for the ith token. For positive indices, Equivalent to: // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd + // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. // shape: [n_embd] (1-dimensional) // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);