llama : support negative ith in llama_get_ API (#6519)
* llama_sampling_sample with default args is more naively usable
  * Batches populated by either llama_batch_get_one or llama_batch_add work with default args
    * Previously, get_one could use the default argument
    * Previously, add should usually have used the last index where logits[idx] == true
  * This hopefully encourages the use of llama_batch_add, by giving expected results when using default arguments
* Adds a "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
  * Believed to work with any currently well-behaved program
  * The default arg now works for both cases (previously it would give strange results for the add case)
  * Any non-negative number is unaffected and behaves as before
  * Negative arguments were previously invalid
  * Implemented as a special case of indexing, as suggested by @compilade in https://github.com/ggerganov/llama.cpp/pull/6519
* Fixed type mismatch errors cited in macOS CI tests
  * Missed in the original updates based on PR feedback in https://github.com/ggerganov/llama.cpp/pull/6519
commit e3c337d87c (parent beea6e1b16)
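Before the diffs, a sketch of the usage pattern this commit enables. With llama_batch_add, logits are typically requested only for the final token, so the output buffer holds a single row at index 0 rather than at index n_tokens - 1; the old default idx = 0 happened to suit llama_batch_get_one but not llama_batch_add, while idx = -1 ("the last output") suits both. This is a minimal sketch, not part of the commit, assuming the common.h/sampling.h helpers of this era and an already-initialized context and sampling context:

#include "common.h"
#include "sampling.h"

#include <stdexcept>
#include <vector>

// Evaluate a prompt with llama_batch_add and sample the next token using
// only default arguments — the pattern this commit makes work out of the box.
static llama_token sample_after_prompt(
        llama_context * ctx,
        llama_sampling_context * ctx_sampling,
        const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t k = 0; k < tokens.size(); ++k) {
        // request logits only for the final position
        llama_batch_add(batch, tokens[k], (llama_pos) k, { 0 }, k + 1 == tokens.size());
    }
    if (llama_decode(ctx, batch) != 0) {
        llama_batch_free(batch);
        throw std::runtime_error("llama_decode failed");
    }
    // the default idx is now -1 ("the last output"), so no index bookkeeping is needed
    const llama_token tok = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_batch_free(batch);
    return tok;
}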
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = -1);

 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
 llama_token_data_array llama_sampling_prepare(
llama.cpp (38 changed lines)
@@ -2177,7 +2177,7 @@ struct llama_context {

     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch

     bool logits_all = false;

@@ -10411,6 +10411,9 @@ static int llama_decode_internal(
         n_outputs_prev += lctx.n_outputs;
     }

+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    lctx.n_outputs = n_outputs;
+
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);

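Context for the added assignment: llama_decode_internal processes a batch as a sequence of ubatches, so after the loop lctx.n_outputs holds only the last ubatch's output count. The assignment restores the whole-batch total, which is what negative indices are resolved against in the getters below. A hypothetical walk-through, with invented numbers:

// Hypothetical values, for illustration only: a batch whose ubatches produced
// 1, 1, and 1 outputs leaves lctx.n_outputs == 1 after the loop, while the
// accumulated total is 3 — the value llama_get_logits_ith needs.
int32_t n_outputs = 3;             // total outputs across all ubatches
int32_t i         = -1;            // caller asks for the last output
int32_t j         = n_outputs + i; // j == 2: the final row of the output buffer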
@@ -15944,23 +15947,31 @@ float * llama_get_logits(struct llama_context * ctx) {
 }

 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
     llama_synchronize(ctx);

     try {
         if (ctx->logits == nullptr) {
             throw std::runtime_error("no logits");
         }
-        if ((size_t) i >= ctx->output_ids.size()) {
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
             throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
         }
-        const int32_t j = ctx->output_ids[i];

         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if ((size_t) j >= ctx->output_size) {
+        if (j >= ctx->n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }

         return ctx->logits + j*ctx->model.hparams.n_vocab;
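Restated outside the getter, the new control flow is a small mapping from a caller-supplied index i to a buffer row j. The helper below is hypothetical (it does not exist in the codebase) but mirrors the checks above:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical restatement of the index resolution in llama_get_logits_ith.
static int32_t resolve_output_row(int32_t i, int32_t n_outputs,
                                  const std::vector<int32_t> & output_ids) {
    int32_t j = -1;
    if (i < 0) {
        // negative indices count back from the end: -1 is the last output
        j = n_outputs + i;
        if (j < 0) {
            throw std::out_of_range("negative index out of range");
        }
    } else if ((size_t) i >= output_ids.size()) {
        throw std::out_of_range("index past the end of the batch");
    } else {
        // output_ids[i] is -1 when logits were never requested at position i
        j = output_ids[i];
    }
    if (j < 0) {
        throw std::out_of_range("no output was stored for this position");
    }
    if (j >= n_outputs) {
        throw std::out_of_range("corrupt output buffer"); // should not happen
    }
    return j;
}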
@@ -15980,23 +15991,32 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 }

 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
     llama_synchronize(ctx);

     try {
         if (ctx->embd == nullptr) {
             throw std::runtime_error("no embeddings");
         }
-        if ((size_t) i >= ctx->output_ids.size()) {
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
             throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
         }
-        const int32_t j = ctx->output_ids[i];

         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if ((size_t) j >= ctx->output_size) {
+        if (j >= ctx->n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }

         return ctx->embd + j*ctx->model.hparams.n_embd;
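The embeddings getter mirrors the logits path line for line, so the same negative indices apply there. A hedged usage fragment, assuming a context created with embeddings enabled (llama_context_params::embeddings, per the API of this era):

// after a successful llama_decode on an embeddings-enabled context:
float * last_embd = llama_get_embeddings_ith(ctx, -1); // embedding of the last output
if (last_embd == nullptr) {
    // invalid index or no embeddings buffer — the C API signals errors via NULL
}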
llama.h (6 changed lines)
@@ -747,8 +747,9 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);

-    // Logits for the ith token. Equivalent to:
+    // Logits for the ith token. For positive indices, equivalent to:
     // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+    // Negative indices can be used to access logits in reverse order, -1 is the last logit.
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

@@ -760,8 +761,9 @@ extern "C" {
     // Otherwise, returns NULL.
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

-    // Get the embeddings for the ith token. Equivalent to:
+    // Get the embeddings for the ith token. For positive indices, equivalent to:
     // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
     // shape: [n_embd] (1-dimensional)
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
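Putting the two documented getters together, here is a minimal sketch of the caller-side contract (NULL signals an invalid id, as the comments above specify; model, ctx, and batch are assumed to be set up as usual):

if (llama_decode(ctx, batch) == 0) {
    const int32_t n_vocab = llama_n_vocab(model);
    float * logits = llama_get_logits_ith(ctx, -1); // logits for the last output
    if (logits != nullptr) {
        // logits[0 .. n_vocab-1] hold the distribution for the newest position
    }
}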